In 2014, Ian J. Goodfellow et al. proposed, in the paper Generative Adversarial Nets, a new framework for estimating generative models via an adversarial process. The framework trains two models simultaneously: a generative model G, which captures the data distribution, and a discriminative model D, which estimates the probability that a sample came from the training data rather than from G. The training procedure for G is to maximize the probability of D making a mistake. It can be shown that in the space of arbitrary functions G and D a unique solution exists, in which G recovers the training data distribution and D outputs 1/2 everywhere.
The basic idea of a Generative Adversarial Network (GAN) is simple: suppose we have two networks, a generator G and a discriminator D. The generator G takes a random noise vector z and produces an image, denoted G(z); the discriminator D judges whether an image x is real, so for an input x, D(x) is the probability that x is a real image. During training, the generator tries to make its images realistic enough that the discriminator cannot tell them from real ones, while D tries its best to distinguish real images from those produced by G. This resembles a two-player game: G and D form a dynamic "game process". As training proceeds, the two networks keep competing until they reach a dynamic equilibrium: the distribution of generated images G(z) approaches the real image distribution, and the discriminator can no longer tell real from fake, i.e. D(G(z)) = 0.5. In the end we obtain a generator network G that can be used to produce images.
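Formally, this game corresponds to the minimax objective from the original paper, stated here for reference:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

D maximizes V by assigning high probability to real samples and low probability to generated ones; G minimizes V by making D(G(z)) as large as it can.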
A more intuitive way to understand a GAN: the generative model can be seen as a team of counterfeiters, trying to produce fake currency and spend it without being caught, while the discriminative model is like the police, trying to detect the counterfeit bills. The counterfeiters want to produce fakes the police cannot recognize; the police want to identify fakes ever more accurately. In this game, both sides keep improving their methods until the counterfeiters' bills can no longer be distinguished from real currency.
The above covers the basic principle of generative adversarial networks. To understand GANs more deeply, we will now use a GAN to generate images from the MNIST dataset.
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt

BATCH_SIZE = 64
UNITS_SIZE = 128
LEARNING_RATE = 0.001
EPOCH = 300
SMOOTH = 0.1

mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)

# Generative model
def generatorModel(noise_img, units_size, out_size, alpha=0.01):
    with tf.variable_scope('generator'):
        FC = tf.layers.dense(noise_img, units_size)
        reLu = tf.nn.leaky_relu(FC, alpha)
        drop = tf.layers.dropout(reLu, rate=0.2)
        logits = tf.layers.dense(drop, out_size)
        outputs = tf.tanh(logits)
        return logits, outputs

# Discriminative model
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        FC = tf.layers.dense(images, units_size)
        reLu = tf.nn.leaky_relu(FC, alpha)
        logits = tf.layers.dense(reLu, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

# Loss functions
"""
The discriminator's goals are:
1. label real images as 1
2. label generated images as 0
The generator's goal: have D label its generated images as 1
"""
def loss_function(real_logits, fake_logits, smooth):
    # The generator wants the discriminator to output 1 for its images;
    # tf.ones_like() creates a tensor with all elements set to 1
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=fake_logits, labels=tf.ones_like(fake_logits) * (1 - smooth)))
    # The discriminator wants to label images produced by the generator as 0
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=fake_logits, labels=tf.zeros_like(fake_logits)))
    # The discriminator wants to label real images as 1 (with label smoothing)
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=real_logits, labels=tf.ones_like(real_logits) * (1 - smooth)))
    # Total discriminator loss
    D_loss = tf.add(fake_loss, real_loss)
    return G_loss, fake_loss, real_loss, D_loss

# Optimizers
def optimizer(G_loss, D_loss, learning_rate):
    train_var = tf.trainable_variables()
    G_var = [var for var in train_var if var.name.startswith('generator')]
    D_var = [var for var in train_var if var.name.startswith('discriminator')]
    # A GAN trains two networks, so G and D are optimized separately,
    # each over its own variables
    G_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(G_loss, var_list=G_var)
    D_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(D_loss, var_list=D_var)
    return G_optimizer, D_optimizer

# Training
def train(mnist):
    image_size = mnist.train.images[0].shape[0]
    real_images = tf.placeholder(tf.float32, [None, image_size])
    fake_images = tf.placeholder(tf.float32, [None, image_size])
    # Run the generator to produce images G_output
    G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)
    # D's verdict on real images
    real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)
    # D's verdict on images produced by G (reusing the same variables)
    fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)
    # Compute the losses (unpacked in the order loss_function returns them)
    G_loss, fake_loss, real_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)
    # Optimize
    G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)
    saver = tf.train.Saver()
    step = 0
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for epoch in range(EPOCH):
            for batch_i in range(mnist.train.num_examples // BATCH_SIZE):
                batch_image, _ = mnist.train.next_batch(BATCH_SIZE)
                # Scale the pixels to (-1, 1), the output range of tanh
                batch_image = batch_image * 2 - 1
                # Noise input for the generative model
                noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))
                session.run(G_optimizer, feed_dict={fake_images: noise_image})
                session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})
                step = step + 1
            # Discriminator D's loss
            loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # D's loss on real images
            loss_real = session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # D's loss on generated images
            loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # Generative model G's loss
            loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})
            print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G', loss_G)
            model_path = os.getcwd() + os.sep + "mnist.model"
            saver.save(session, model_path, global_step=step)

def main(argv=None):
    train(mnist)

if __name__ == '__main__':
    tf.app.run()
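For reference, the losses built in loss_function correspond, up to the label-smoothing factor (1 - smooth) applied to the targets of value 1, to the standard cross-entropy GAN losses, with the generator using the common non-saturating form:

$$L_D = -\mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] - \mathbb{E}_{z}[\log(1 - D(G(z)))], \qquad L_G = -\mathbb{E}_{z}[\log D(G(z))]$$

This is because tf.nn.sigmoid_cross_entropy_with_logits with target 1 yields -log D, and with target 0 yields -log(1 - D); training G to maximize log D(G(z)) rather than to minimize log(1 - D(G(z))) gives stronger gradients early in training.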
The code above trains the model; the code below tests it, loading the parameters saved during training. The generatorImage function generates handwritten-digit images, of which 25 are displayed here. The generated images are shown in Figure 1 below; one can still roughly make out the digits in them.
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import pickle
import mnist_GAN

UNITS_SIZE = mnist_GAN.UNITS_SIZE

def generatorImage(image_size):
    sample_images = tf.placeholder(tf.float32, [None, image_size])
    # Rebuild the generator graph so the saved variables can be restored into it
    G_logits, G_output = mnist_GAN.generatorModel(sample_images, UNITS_SIZE, image_size)
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Restore the latest checkpoint saved by the training script
        saver.restore(session, tf.train.latest_checkpoint('.'))
        sample_noise = np.random.uniform(-1, 1, size=(25, image_size))
        samples = session.run(G_output, feed_dict={sample_images: sample_noise})
    with open('samples.pkl', 'wb') as f:
        pickle.dump(samples, f)

def show():
    with open('samples.pkl', 'rb') as f:
        samples = pickle.load(f)
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    for ax, image in zip(axes.flatten(), samples):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Each sample is a flat 784-vector; reshape to 28x28 for display
        ax.imshow(image.reshape((28, 28)), cmap='Greys_r')
    plt.show()

def main(argv=None):
    image_size = mnist_GAN.mnist.train.images[0].shape[0]
    generatorImage(image_size)
    show()

if __name__ == '__main__':
    tf.app.run()
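Note that this test script assumes the training code above was saved as mnist_GAN.py (the module name is taken from the import statement) and has already been run, so that a mnist.model-* checkpoint exists in the current directory; otherwise tf.train.latest_checkpoint('.') returns None and the restore fails.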
Figure 1. Generated images
The above builds a simple GAN on the MNIST dataset. Both the generative and the discriminative model use only a simple fully connected network. For image processing, convolutional neural networks perform better, so replacing the generator and discriminator with deep convolutional networks would produce clearer images. Many GAN variants also exist by now; I will write them up in later posts.
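As an illustration of that direction, here is a minimal, untested sketch of a DCGAN-style generator in the same TF 1.x style as the code above. The function name dcganGenerator and the layer sizes are assumptions for illustration, not part of the original code:

# A sketch of a DCGAN-style generator: project the noise to a small
# feature map, then upsample with transposed convolutions to a 28x28 image.
def dcganGenerator(noise_img, alpha=0.01, reuse=False):
    with tf.variable_scope('generator', reuse=reuse):
        # Project and reshape the noise vector to a 7x7x128 feature map
        FC = tf.layers.dense(noise_img, 7 * 7 * 128)
        reshaped = tf.reshape(FC, (-1, 7, 7, 128))
        reLu = tf.nn.leaky_relu(reshaped, alpha)
        # Upsample 7x7 -> 14x14
        conv1 = tf.layers.conv2d_transpose(reLu, 64, 5, strides=2, padding='same')
        reLu1 = tf.nn.leaky_relu(conv1, alpha)
        # Upsample 14x14 -> 28x28, one channel, tanh output in (-1, 1)
        logits = tf.layers.conv2d_transpose(reLu1, 1, 5, strides=2, padding='same')
        outputs = tf.tanh(logits)
        return logits, outputs

Since this generator outputs image tensors of shape (28, 28, 1) rather than flat 784-vectors, the flat MNIST batches would also need to be reshaped before being fed to a matching convolutional discriminator.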
Reference: https://blog.csdn.net/sinat_33741547/article/details/77751035
Original article: https://blog.csdn.net/weixin_42111770/article/details/81449449