GAN网络通俗解释(图画版)

简介: 最通俗的GAN网络介绍!

在本教程中,你将了解什么是生成敌对网络(GAN),并且在整个过程中不涉及负责的数学细节。之后,你还将学习如何编写一个可以创建数字的简单GAN!

42b88334a7ca0218a51234b6a5dc845eec862baa

什么是GAN(插画版介绍)

理解GAN的最简单方法是通过一个简单的比喻:

假设有一家商店它们从顾客那里购买某些种类的葡萄酒,用于以后再销售。

b5454da980a2ff7cbf7defda5ad243d731973c7d

然而,有些恶意的顾客为了获得金钱而出售假酒。在这种情况下,店主必须能够区分假酒和正品葡萄酒。

8df9a7b2f6df0dc56cd1c2e9c03df574bf7b3f0e

你可以想象,最初,伪造者在尝试出售假酒时可能会犯很多错误,并且店主很容易认定该酒不是真的。由于这些失败,伪造者会继续尝试使用不同的技术来模拟真正的葡萄酒,最终才有可能成功。现在,伪造者知道某些技术已经超过了店主的认识假酒的能力,他可以开始进一步生产基于这些技术的假酒。

同时,店主可能会从其他店主或葡萄酒专家那里得到一些反馈,说明他拥有的一些葡萄酒不是原装的。这意味着店主必须改善他是如何确定葡萄酒是伪造的还是真实的。伪造者的目标是制造与真实葡萄酒无法区分的葡萄酒,而店主的目标是准确地分辨葡萄酒是否真实。

这种来回的竞争博弈就是GAN网络背后的主要思想。

生成敌对网络的组成部分

使用上面的例子,我们可以想出一个GAN的体系结构。

cc9b2ea37fbc7240d5ec41bcc350e26c33034e97

GAN网络中有两个主要组件:生成器和鉴别器。这个例子中的店主被称为鉴别器网络,并且通常是卷积神经网络(因为GAN主要用于图像任务),其主要功能是判断图像是真实的概率。

伪造者被称为生成网络,并且通常也是卷积神经网络(具有解卷积层)。该网络需要一些噪声矢量并输出图像。在训练生成网络时,它会学习图像的哪些区域进行改进/更改,以便鉴别器将难以将其生成的图像与真实图像区分开来。

生成网络不断生成更接近真实图像的图像,而辨别网络试图确定真实图像和假图像之间的差异。最终的目标是建立一个可生成与真实图像无法区分的图像的生成网络。

一个简单的Keras生成对抗网络

现在你已经了解了GAN是什么以及它们的主要组成部分,现在我们可以开始编写一个非常简单的代码。本教程将使用Keras,如果你不熟悉此Python库,则应在继续之前阅读翻译小组其他文章。本教程是基于这里开发的非常酷且易于理解的GAN。

你需要做的第一件事是通过以下方式安装以下软件包pip

- keras
- matplotlib
- tensorflow
- tqdm

你将matplotlib用于绘制tensorflow——Keras后端库,并用tqdm为每个时期(迭代)显示一个奇特的进度条。

下一步是创建一个Python脚本。在这个脚本中,你首先需要导入你将要使用的所有模块和函数,在使用它们时将给出每个解释。

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

你现在想要设置一些变量值:

# Let Keras know that we are using tensorflow as our backend engine
os.environ["KERAS_BACKEND"] = "tensorflow"
# To make sure that we can reproduce the experiment and get the same results
np.random.seed(10)
# The dimension of our random noise vector.
random_dim = 100
 

在开始构建鉴别器和生成器之前,你应该首先收集并预处理数据。你将使用现在最流行的MNIST数据集,该数据集具有一组从0到9范围内的单个数字的图像。

d232ec9c62087f6060aac961678344187aca3140

def load_minst_data():
    # load the data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # normalize our inputs to be in the range[-1, 1] 
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
    # 784 columns per row
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)

请注意mnist.load_data()这个函数是Keras的一部分,它允许你轻松将MNIST数据集导入你的工作区。

现在,你可以创建你的生成器和鉴别器网络。你可以为这两个网络使用Adam优化器。对于生成器和鉴别器,你将创建一个带有三个隐藏层的神经网络,激活函数为Leaky Relu。你还应该为鉴别器添加Drop-out图层,以提高其对未见图像的鲁棒性。

def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5) 
def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator
def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

终于到了将生成器和鉴别器放在一起的时候了!

def get_gan_network(discriminator, random_dim, generator, optimizer):
    # We initially set trainable to False since we only want to train either the 
    # generator or discriminator at a time
    discriminator.trainable = False
    # gan input (noise) will be 100-dimensional vectors
    gan_input = Input(shape=(random_dim,))
    # the output of the generator (an image)
    x = generator(gan_input)
    # get the output of the discriminator (probability if the image is real or not)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan

为了保持整个过程的完整性,你可以创建一个功能,每20个纪元保存你生成的图像。由于这不是本教程的核心,所以你不需要完全理解该功能。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

你现在已经编码了大部分网络,剩下的就是训练这个网络,并看看你创建的图像。

def train(epochs=1, batch_size=128):
    # Get the training and testing data
    x_train, y_train, x_test, y_test = load_minst_data()
    # Split the training data into batches of size 128
    batch_count = x_train.shape[0] / batch_size
    # Build our GAN netowrk
    adam = get_optimizer()
    generator = get_generator(adam)
    discriminator = get_discriminator(adam)
    gan = get_gan_network(discriminator, random_dim, generator, adam)

    for e in xrange(1, epochs+1):
        print '-'*15, 'Epoch %d' % e, '-'*15
        for _ in tqdm(xrange(batch_count)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

            # Generate fake MNIST images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            # Labels for generated and real data
            y_dis = np.zeros(2*batch_size)
            # One-sided label smoothing
            y_dis[:batch_size] = 0.9
            # Train discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)
            # Train generator
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator)

if __name__ == '__main__':
    train(400, 128)

训练400个纪元后,你可以查看生成的图像。查看第一个纪元后产生的图像,可以看到它没有任何真实的结构,在40个纪元后查看图像,数字开始成形,最后,400个纪元后产生的图像显示出清晰的数字,尽管是一对夫妇仍然无法辨认。

9ac8bb04a0b0adb56a11b654561e69480081ac38

1纪元(左)后的结果40个纪元后(中)的结果400个时代后的结果(右)

此代码在CPU上每个纪元大约需要2分钟,这是选择此代码的主要原因。你可以尝试使用更多的纪元,并通过向生成器和鉴别器添加更多(和不同的)图层。但是,当使用更复杂和更深的体系结构时,如果仅使用CPU,则运行时也会增加。

结论

恭喜,你已经完成了本教程的最后部分,你已经以直观的方式学习生成敌对网络(GAN)的基础知识!

数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧!

本文由@阿里云云栖社区组织翻译。

文章原标题《demystifying-generative-adversarial-networks》,

译者:虎说八道,审校:袁虎。

文章为简译,更为详细的内容,请查看原文 

相关文章
|
3月前
|
安全 API Android开发
Android网络和数据交互: 解释Retrofit库的作用。
Android网络和数据交互: 解释Retrofit库的作用。
38 0
|
3月前
|
Android开发 开发者
Android网络和数据交互: 请解释Android中的AsyncTask的作用。
Android网络和数据交互: 请解释Android中的AsyncTask的作用。
21 0
|
3月前
|
JSON Java Android开发
Android网络和数据交互: 请解释Android中的JSON解析库,如Gson。
Android网络和数据交互: 请解释Android中的JSON解析库,如Gson。
24 0
|
4月前
|
机器学习/深度学习 编解码 TensorFlow
【Keras+计算机视觉+Tensorflow】生成对抗神经网络中DCGAN、CycleGAN网络的讲解(图文解释 超详细)
【Keras+计算机视觉+Tensorflow】生成对抗神经网络中DCGAN、CycleGAN网络的讲解(图文解释 超详细)
49 0
|
4月前
|
机器学习/深度学习 自动驾驶 算法
【计算机视觉+自动驾驶】二、多任务深度学习网络并联式、级联式构建详细讲解(图像解释 超详细必看)
【计算机视觉+自动驾驶】二、多任务深度学习网络并联式、级联式构建详细讲解(图像解释 超详细必看)
69 1
|
4月前
|
机器学习/深度学习 传感器 自动驾驶
【计算机视觉】一、多任务深度学习网络的概念及在自动驾驶中的应用讲解(图文解释 超详细)
【计算机视觉】一、多任务深度学习网络的概念及在自动驾驶中的应用讲解(图文解释 超详细)
44 0
|
4月前
|
机器学习/深度学习 自然语言处理 算法
【深度学习】常用算法生成对抗网络、自编码网络、多层感知机、反向传播等讲解(图文解释 超详细)
【深度学习】常用算法生成对抗网络、自编码网络、多层感知机、反向传播等讲解(图文解释 超详细)
41 0
|
8月前
|
机器学习/深度学习
【RL-GAN-Net】强化学习控制GAN网络,用于实时点云形状的补全。
【RL-GAN-Net】强化学习控制GAN网络,用于实时点云形状的补全。
141 0
|
10月前
|
机器学习/深度学习 编解码 人工智能
深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决
深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决
深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决
|
11月前
|
机器学习/深度学习 人工智能 数据可视化
阿斯利康团队用具有域适应性的可解释双线性注意网络改进了药物靶标预测
阿斯利康团队用具有域适应性的可解释双线性注意网络改进了药物靶标预测
171 0

热门文章

最新文章