GAIN
别人缺陷
基于深度学习算法包括去噪自动编码器(DAE)和生成对抗网络(GAN)。DAE在训练时需要使用完整的数据集。也有一些DAE方法不需要完整数据集,但是只用了能够观察到的数据来进行训练。有方法使用深度卷积GAN来做图像修复,也是需要完整的数据来训练鉴别器D。
问题主要出在需要完整数据集,没有的话就只使用数据中完整的那一部分。
创新点
- 没有完整数据集,也可以使用
- 设计了一个提示矩阵,为鉴别器提供额外的信息(后续说明),这种提示确保了生成器根据真实的底层数据分布生成样本。
修复问题定义
d维输入(数据向量 data vector)X,X=X_1×X_2…×X_d=(X_1,…,X_d)X=X_1×X_2…×X_d=(X_1,…,X_d),其分布为P(X)。掩码矩阵(mask vector) M=(M_1,…,M_d)M=(M_1,…,M_d),元素的值随机为0或1。再定义一个X~tilde{X}。M可以表示为X的哪些分量是已被观察的(即非缺失的),且从X~tilde{X}中可以恢复M。
通过对数据的分布进行建模,而不是仅仅对期望建模,这样的话可以进行多次绘制,也就是多次估算,使得我们能够捕获估算值的不确定性。大意就是对数据分布建模可以进行多次估算确实的值。
模型设计
生成器G
X~tilde{X}, M 和噪音Z(d维噪声)作为输入,输出为Xˉbar{X}。定义G:
Xˉ=G(X~,M,(1−M)⊙Z)X^=M⊙X~+(1−M)⊙Xˉbar{X}=G(tilde{X},M,(1-M)⊙Z) hat{X}=M⊙tilde{X}+(1-M)⊙bar{X}
X^hat{X}是修复后的完整数据。
生成器就是简单的全连接层
鉴别器D
标准的GAN,鉴别器是判断生成器数据是真还是假,然而这里的生成器的数据是真与假的结合。
现在这里的鉴别器D识别的不是整个向量是真是假,而是区分一个向量里哪些元素是真,哪些是假,相当于预测出掩码矩阵。即如果鉴别器很强大,那么预测出的掩码矩阵就是我们原先定义的M。
为什么这样设计鉴别器:如果生成器很强大,生成了满足原先分布的数据,鉴别器完全鉴别不出来,则鉴别器鉴别出的掩码矩阵就是全1。我们希望鉴别器能够鉴别出哪些是生成的,如果鉴别出来,就会越来越趋近于M。
鉴别器也是简单的全连接层。
提示矩阵Hint
揭示了原始数据中缺失部分的某些信息,让D更加关注它所提示的部分,同时也逼迫G生成更加接近真实的数据用来填补。提示矩阵是基于掩码矩阵来自定义的,提示矩阵和G生成的数据共同作为输入。为什么可以产生这种效果,详见论文推导。
训练
训练D以最大化正确预测M的概率(训练D来让D的输出接近M),训练G来最小化D预测M的概率(训练G来让D的输出远离M) 。通俗来讲,D为了判别生成的数据是假,所以对于缺失位置,尽可能判断为0,即逐渐趋向于M。定义V(D, G):
先训练D,设D的输出为D_prob(注意,输入D中的数据new_x,未缺失位置的数据使用原值,缺失位置的数据使用生成值)。D的损失如下:
D_loss = -torch.mean(m * torch.log(D_prob + 1e-8) + (1 - m) * torch.log(1. - D_prob + 1e-8))
如果D判别未缺失位置的数接近1,则torch.mean(m * torch.log(D_prob + 1e-8)
越大,那么D_loss越小。
如果D判别缺失位置的数据接近1,则(1 - m) * torch.log(1. - D_prob + 1e-8))
越大,那么D_loss越大。
再训练G,设G的输出为G_sample。G的损失如下
G_loss = -torch.mean((1 - m) * torch.log(D_prob + 1e-8)) +
alpha * torch.mean((m * new_x - m * G_sample) ** 2) / torch.mean(m)
torch.mean((1 - m) * torch.log(D_prob + 1e-8))
上面解释过。
torch.mean((m * new_x - m * G_sample) ** 2) / torch.mean(m)
表示未缺失位置的MSE
总体而言
- G让生成的未缺失数据逼近于真实数据分布
- D是被出哪些是缺失数据,哪些是未缺失数据
- 没有使用到缺失位置的loss,因为实际场景下数据是缺失的,确实位置是无法计算loss的
pytorch实现代码:github.com/RadishVeget…
一些参考: