关键词

深度学习:GAN 对抗网络原理详细解析(零基础必看)

深度学习:GAN 对抗网络原理详细解析(零基础必看)

什么是GAN网络

GAN的全称是Generative adversarial network,中文翻译过来就是对抗式神经网络。对抗神经网络其实是两个网络的组合,可以理解为一个网络生成模拟数据(生成网络Generator),另一个网络判断生成的数据是真实的还是模拟的(判别网络Discriminator)。生成网络要不断优化自己生成的数据让判别网络判断不出来,判别网络也要优化自己让自己判断得更准确。二者关系形成对抗,因此叫对抗神经网络。

GAN的意义及应用场景

意义
我们了解了GAN的定义,那么学习或者说是了解GAN有什么意义?GAN又有哪些应用场景呢?

有人说GAN强大之处在于可以自动的学习原始真实样本集的数据分布,不管这个分布多么的复杂,只要训练的足够好就可以学出来。针对这一点,感觉有必要好好理解一下为什么别人会这么说。
我们知道,传统的机器学习方法,我们一般都会定义一个什么模型让数据去学习。比如说假设我们知道原始数据属于高斯分布呀,只是不知道高斯分布的参数,这个时候我们定义高斯分布,然后利用数据去学习高斯分布的参数得到我们最终的模型。再比如说我们定义一个分类器,比如SVM,然后强行让数据进行东变西变,进行各种高维映射,最后可以变成一个简单的分布,SVM可以很轻易的进行二分类分开,其实SVM已经放松了这种映射关系了,但是也是给了一个模型,这个模型就是核映射(什么径向基函数等等),说白了其实也好像是你事先知道让数据该怎么映射一样,只是核映射的参数可以学习罢了。
所有的这些方法都在直接或者间接的告诉数据你该怎么映射一样,只是不同的映射方法能力不一样。那么我们再来看看GAN,生成模型最后可以通过噪声生成一个完整的真实数据(比如人脸),说明生成模型已经掌握了从随机噪声到人脸数据的分布规律了,有了这个规律,想生成人脸还不容易。然而这个规律我们开始知道吗?显然不知道,如果让你说从随机噪声到人脸应该服从什么分布,你不可能知道。这是一层层映射之后组合起来的非常复杂的分布映射规律。然而GAN的机制可以学习到,也就是说GAN学习到了真实样本集的数据分布。
还有人说GAN强大之处在于可以自动的定义潜在损失函数。 什么意思呢,这应该说的是判别网络可以自动学习到一个好的判别方法,其实就是等效的理解为可以学习到好的损失函数,来比较好或者不好的判别出来结果。虽然大的loss函数还是我们人为定义的,基本上对于多数GAN也都这么定义就可以了,但是判别网络潜在学习到的损失函数隐藏在网络之中,不同的问题这个函数就不一样,所以说可以自动学习这个潜在的损失函数。

应用场景
GAN的应用场景包括:
1,数据生成,主要指图像生成。常用的有DCGAN WGAN,BEGAN;
2,GAN本身也是一种无监督学习的典范,因此它在无监督学习,半监督学习领域都有广泛的应用;
3,不仅在生成领域,GAN在分类领域也占有一席之地,简单来说,就是替换判别器为一个分类器,做多分类任务,而生成器仍然做生成任务,辅助分类器训练;
4,GAN可以和强化学习结合,目前一个比较好的例子就是seq-GAN;
5,目前比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了,都有比较好的结果;
6,图像数据增强。

GAN的基本网络结构


这就是GAN网络的基本形式,我们可以发现其实就像文章开头所述,GAN是由两个网络(生成网络和识别网络)组合而成。
现在我们仔细分析一下上面这个网络:
生成网络(Generator):输入为随机数据,输出为生成数据(通常是图像)。通常这个网络选用最普通的多层随机网络即可,网络太深容易引起梯度消失或者梯度爆炸。下图是生成网络的黑盒效果示意图。图中我们输入一个一维数组,通过Generator网络生成一张图片。我们通过调整输入的数据或者是网络参数可以改变输出的图片效果。

识别网络(Discriminator):现在,我们把生成网络生成的数据称为假数据,对应的,来自真实数据集的数据称为真数据。判别网络输入为数据(这里指代真实图像和生成图像),输出一个判别概率。需注意的是,这里判别的是图像的真伪,而非图像的类别。输入一个图片后,我们并不需要确认这张图片是个啥,而是判别图像到底来自于真实数据集,还是生成网络的胡乱合成。所以输出一个一维条件概率(伯努利分布的概率参数)就好了。网络实现同样可用最基本的多层神经网络。

如何优化网络(定义损失)

既然我们知道了对抗网络最后做的就是一个二分类问题,那么问题来了?如何优化这个网络或者说我们如何定义损失函数?
其实很简单,GAN有两个网络,那么自然就有两个损失函数。
生成网络的损失函数
LG=H(1,D(G(z)))L_{G}=H(1,D(G(z)))LG=H(1,D(G(z)))
上式中,GGG代表生成网络,DDD代表判别网络,HHH代表交叉熵,zzz是输入随机数据。D(G(z))D(G(z))D(G(z))是对假数据的判断概率,1代表数据绝对真实,0代表数据绝对虚假。H(1,D(G(z)))H(1,D(G(z)))H(1,D(G(z)))代表判断结果与1的距离。如果读者对交叉熵损失函数不了解,可以参考我的另一篇博文啥也不会照样看懂交叉熵损失函数

识别网络的损失函数
LD=H(1,D(x))+H(0,D(G(z)))L_{D}=H(1,D(x))+H(0,D(G(z)))LD=H(1,D(x))+H(0,D(G(z)))
上式中,xxx是真实数据,这里要注意的是H(1,D(x))H(1,D(x))H(1,D(x))代表真实数据与1的距离,H(0,D(G(z)))H(0,D(G(z)))H(0,D(G(z)))代表生成数据与0的距离。很显然,识别网络要想取得良好的效果,那么就要做到,在它眼里,真实数据就是真实数据,生成数据就是虚假数据(即真实数据与1的距离小,生成数据与0的距离小)。

训练过程
(该段部分内容参考自博客GAN神经网络分析
GAN对抗网络的训练过程通常是两个网络单独且交替训练:先训练识别网络,再训练生成网络,再训练识别网络,如此反复,直到达到纳什均衡。

假设现在生成网络模型已经有了(当然可能不是最好的生成网络),那么给一堆随机数组,就会得到一堆假的样本集(因为不是最终的生成模型,那么现在生成网络可能就处于劣势,导致生成的样本就不咋地,可能很容易就被判别网络判别出来了说这货是假冒的),但是先不管这个,假设我们现在有了这样的假样本集,真样本集一直都有,现在我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,很明显这里我们就已经默认真样本集所有的类标签都为1,而假样本集的所有类标签都为0.

有人会说,在真样本集里面的人脸中,可能张三人脸和李四人脸不一样呀,对于这个问题我们需要理解的是,我们现在的任务是什么,我们是想分样本真假,而不是分真样本中那个是张三label、那个是李四label。况且我们也知道,原始真样本的label我们是不知道的。回过头来,我们现在有了真样本集以及它们的label(都是1)、假样本集以及它们的label(都是0),这样单就判别网络来说,此时问题就变成了一个再简单不过的有监督的二分类问题了,直接送到神经网络模型中训练就完事了。假设训练完了,下面我们来看生成网络。

对于生成网络,想想我们的目的,是生成尽可能逼真的样本。那么原始的生成网络生成的样本你怎么知道它真不真呢?就是送到判别网络中,所以在训练生成网络的时候,我们需要联合判别网络一起才能达到训练的目的。什么意思?就是如果我们单单只用生成网络,那么想想我们怎么去训练?误差来源在哪里?细想一下没有,但是如果我们把刚才的判别网络串接在生成网络的后面,这样我们就知道真假了,也就有了误差了。所以对于生成网络的训练其实是对生成-判别网络串接的训练。好了那么现在来分析一下样本,原始的噪声数组Z我们有,也就是生成了假样本我们有,此时很关键的一点来了,我们要把这些假样本的标签都设置为1,也就是认为这些假样本在生成网络训练的时候是真样本。

为什么要这样呢?我们想想,是不是这样才能起到迷惑判别器的目的,也才能使得生成的假样本逐渐逼近为正样本。好了,重新顺一下思路,现在对于生成网络的训练,我们有了样本集(只有假样本集,没有真样本集),有了对应的label(全为1),是不是就可以训练了?有人会问,这样只有一类样本,训练啥呀?谁说一类样本就不能训练了?只要有误差就行(生成网络的数据后面给识别器看,看最终结果如果loss值很低,则生成器成功欺骗了识别器(把假数据当成和label一样也是1了),如果loss很大(label上尽管是1,但是识别器还是预测为0,识别器是真的认出来了),说明生成器还需提升)。还有人说,你这样一训练,判别网络的网络参数不是也跟着变吗?没错,这很关键,所以在训练这个串接的网络的时候,一个很重要的操作就是不要判别网络的参数发生变化,也就是不让它参数发生更新,只是把误差一直传,传到生成网络那块后更新生成网络的参数。这样就完成了生成网络的训练了。

在完成生成网络训练好,那么我们是不是可以根据目前新的生成网络再对先前的那些噪声Z生成新的假样本了,没错,并且训练后的假样本应该是更真了才对。然后又有了新的真假样本集(其实是新的假样本集),接着真假样本集又都给识别器训练,这样又可以重复上述过程了。我们把这个过程称作为单独交替训练。我们可以实现定义一个迭代次数,交替迭代到一定次数后停止即可。这个时候我们再去看一看噪声Z生成的假样本会发现,原来它已经很真了。

看完了这个过程是不是感觉GAN的设计真的很巧妙,个人觉得最值得称赞的地方可能在于这种假样本在训练过程中的真假变换,这也是博弈得以进行的关键之处。假样本集在训练识别器时候label为0,是为方便计算loss,检验有多少成功欺骗了识别器,被识别器预测为1了。假样本集在训练生成器时候label为1,也是为方便计算loss,检验有多少被识别器发现了,来提升识别器的性能。我们最终目的是得到一个如火纯情的造假者的生成器!识别器是辅助工具罢了。但是识别器也不能太差劲了,得2个同时提升性能,才能达到一个我们理想的生成器。关键在于交替训练的时候要平衡的交替,不能一方太强,否则2者一起训练提升就无法继续了

GAN网络的局限性

如此神奇且强大的对抗网络也有它力所不逮的地方,那就是它无法处理文本数据。

文本数据相比较图片数据来说是离散的,因为对于文本来说,通常需要将一个词映射为一个高维的向量,最终预测的输出是一个one-hot向量,假设softmax的输出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那么变为onehot是(0,1,0,0,0,0),如果softmax输出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以对于生成器来说,GGG输出了不同的结果但是D给出了同样的判别结果,并不能将梯度更新信息很好的传递到GGG中去,所以DDD最终输出的判别没有意义。

一个小栗子

关于GAN对抗网络的一个实际应用就是垃圾邮件的处理问题。假设有一个叫Gary的营销人员试图骗过David的垃圾邮件分类器来发送垃圾邮件。Gary希望能尽可能地发送多的垃圾邮件,David希望尽可能少的垃圾邮件通过。理想情况下会达到纳什均衡,尽管我们谁都不想收到垃圾邮件。具体可以参考知乎上的一篇文章如何形象又有趣的讲解对抗神经网络GAN是什么?

本文链接:http://task.lmcjl.com/news/12158.html

展开阅读全文