首页 > 人工智能(Artificial Intelligence) > 生成对抗网络

生成对抗网络(GAN)的训练

GAN 的基本思想就是一个最小最大定理,当两个玩家(D 和 G)彼此竞争时(零和博弈),双方都假设对方采取最优的步骤而自己也以最优的策略应对(最小最大策略),那么结果就已经预先确定了,玩家无法改变它(纳什均衡)。

对抗样本

对抗样本(adversarial example),它是指经过精心计算得到的用于误导分类器的样本。

图像分类器本质上是高维空间的一个复杂的决策边界。当然涉及到图像分类的时候,由于是高维空间而不是简单的两维或者三维空间,我们无法画出这个边界出来。但是我们可以肯定的是,训练完成后,分类器是无法泛化到所有数据上,除非我们的训练集包含了分类类别的所有数据,但实际上我们做不到。而做不到泛化到所有数据的分类器,其实就会过拟合训练集的数据,这也就是我们可以利用的一点。

我们给图片添加一个非常接近于 0 的随机噪声,这可以通过控制噪声的 L2 范数来实现。L2 范数可以看做是一个向量的长度,这里有个诀窍就是图片的像素越多,即图片尺寸越大,其平均 L2 范数也就越大。因此,当添加的噪声的范数足够低,那么视觉上你不会觉得这张图片有什么不同。但是,在向量空间上,添加噪声后的图片和原始图片已经有很大的距离了。除了这种简单的添加随机噪声,还可以通过图像变形的方式,使得新图像和原始图像视觉上一样的情况下,让分类器得到有很高置信度的错误分类结果。这种过程也被称为对抗攻击(adversarial attack),这种生成方式的简单性也是给 GAN 提供了解释。

模型训练

生成模型与对抗模型可以说是完全独立的两个模型,训练这样的两个模型的大方法就是:单独交替迭代训练

假设现在生成网络模型已经有了(当然可能不是最好的生成网络),那么给一堆随机数组,就会得到一堆假的样本集(因为不是最终的生成模型,那么现在生成网络可能就处于劣势,导致生成的样本就不咋地,可能很容易就被判别网络判别出来了说这货是假冒的),但是先不管这个,假设我们现在有了这样的假样本集,真样本集一直都有,现在我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,很明显这里我们就已经默认真样本集所有的类标签都为1,而假样本集的所有类标签都为0. 

对于生成网络目的,是生成尽可能逼真的样本。那么原始的生成网络生成的样本你怎么知道它真不真呢?就是送到判别网络中,所以在训练生成网络的时候,我们需要联合判别网络一起才能达到训练的目的。如果我们把刚才的判别网络串接在生成网络的后面,这样我们就知道真假了,也就有了误差了。所以对于生成网络的训练其实是对生成-判别网络串接的训练。好了那么现在来分析一下样本,原始的噪声数组Z我们有,也就是生成了假样本我们有,此时很关键的一点来了,我们要把这些假样本的标签都设置为1,也就是认为这些假样本在生成网络训练的时候是真样本。

在完成生成网络训练好,那么我们是不是可以根据目前新的生成网络再对先前的那些噪声Z生成新的假样本了,没错,并且训练后的假样本应该是更真了才对。然后又有了新的真假样本集(其实是新的假样本集),这样又可以重复上述过程了。我们把这个过程称作为单独交替训练。我们可以实现定义一个迭代次数,交替迭代到一定次数后停止即可。这个时候我们再去看一看噪声Z生成的假样本会发现,原来它已经很真了。

GAN强大之处在于能自动学习原始真实样本集的数据分布,不管这个分布多么的复杂,只要训练的足够好就可以学出来。GAN的生成模型最后可以通过噪声生成一个完整的真实数据(比如人脸),说明生成模型掌握了从随机噪声到人脸数据的分布规律。GAN一开始并不知道这个规律是什么样,也就是说GAN是通过一次次训练后学习到的真实样本集的数据分布。

论文中的一张图来解释: 

v2-aab535a56ee0fabaa3d52998d1baf616_r.jpg

(黑色的线表示数据x的实际分布,绿色的线表示数据的生成分布,蓝色的线表示生成的数据对应在判别器中的分布效果)

对于图a,D还刚开始训练,本身分类的能力还很有限,有波动,但是初步区分实际数据和生成数据还是可以的。图b,D训练得比较好了,可以很明显的区分出生成数据。然后对于图c:绿色的线与黑色的线的偏移,蓝色的线下降了,也就是生成数据的概率下降了。那么,由于绿色的线的目标是提升概率,因此就会往蓝色线高的方向移动。那么随着训练的持续,由于G网络的提升,G也反过来影响D的分布。假设固定G网络不动,训练D,那么训练到最优,$D^*_g(x) = p_{data}(x)/(p_{data}(x)+p_{g}(x))$。因此,随着$p_g(x)$趋近于$p_{data}(x),D^*_g(x)$会趋近于0.5,也就是到图d。而我们的目标就是希望绿色的线能够趋近于黑色的线,也就是让生成的数据分布与实际分布相同。图d符合我们最终想要的训练结果。到这里,G网络和D网络就处于纳什均衡状态,无法再进一步更新了。

这张图表明的是GAN的生成网络如何一步步从均匀分布学习到正太分布的。原始数据x服从正太分布,这个过程你也没告诉生成网络说你得用正太分布来学习,但是生成网络学习到了。假设你改一下x的分布,不管什么分布,生成网络可能也能学到。这就是GAN可以自动学习真实数据的分布的强大之处。

原论文的整体算法:

v2-ce9f2c59b74e1970692fafe7d713fc99_hd.jpg

训练的技巧和方法:

首先 G 和 D 是同步训练,但两者训练次数不一样,通常是 D 网络训练 k 次后,G 训练一次。主要原因是 GAN 刚开始训练时候会很不稳定;

D 的训练是同时输入真实数据和生成数据来计算 loss,而不是采用交叉熵(cross entropy)分开计算。不采用 cross entropy 的原因是这会让 D(G(z)) 变为 0,导致没有梯度提供给 G 更新,而现在 GAN 的做法是会收敛到 0.5;

实际训练的时候,作者是采用 -log(D(G(z))) 来代替 log(1-D(G(z))) ,这是希望在训练初始就可以加大梯度信息,这是因为初始阶段 D 的分类能力会远大于 G 生成足够真实数据的能力,但这种修改也将让整个 GAN 不再是一个完美的零和博弈。

基本流程如下:

初始化判别器D的参数 $\theta_{d}$ 和生成器G的参数$ \theta_g$ 。

从真实样本中采样 m 个样本$ { x^1, x^2, ... x^m  }$ ,从先验分布噪声中采样 m 个噪声样本 ${ z^1, z^2, ...,z^m  }$ 并通过生成器获取 m  个生成样本$ { \widetilde x^1, \widetilde x^2, ..., \widetilde x^m } $。固定生成器G,训练判别器D尽可能好地准确判别真实样本和生成样本,尽可能大地区分正确样本和生成的样本。

循环k次更新判别器之后,使用较小的学习率来更新一次生成器的参数,训练生成器使其尽可能能够减小生成样本与真实样本之间的差距,也相当于尽量使得判别器判别错误。

多次更新迭代之后,最终理想情况是使得判别器判别不出样本来自于生成器的输出还是真实的输出。亦即最终样本判别概率均为0.5。

Tips: 之所以要训练k次判别器,再训练生成器,是因为要先拥有一个好的判别器,使得能够教好地区分出真实样本和生成样本之后,才好更为准确地对生成器进行更新。

关闭
感谢您的支持,我会继续努力!
扫码打赏,建议金额1-10元


提醒:打赏金额将直接进入对方账号,无法退款,请您谨慎操作。