深度学习界明星:生成对抗网络与Improving GAN

本文涉及的产品
简介: 2014年,深度学习三巨头之一IanGoodfellow提出了生成对抗网络(Generative Adversarial Networks, GANs)这一概念,刚开始并没有引起轰动,直到2016年,学界、业界对它的兴趣如“井喷”一样爆发,多篇重磅文章陆续发表。

2014年,深度学习三巨头之一IanGoodfellow提出了生成对抗网络(Generative Adversarial Networks, GANs)这一概念,刚开始并没有引起轰动,直到2016年,学界、业界对它的兴趣如“井喷”一样爆发,多篇重磅文章陆续发表。2016年12月NIPS大会上,Goodfellow做了关于GANs的专题报告,使得GANs成为了当今最热门的研究领域之一,本文将介绍如今深度学习界的明星——生成对抗网络。

1何为生成对抗网络

生成对抗网络,根据它的名字,可以推断这个网络由两部分组成:第一部分是生成,第二部分是对抗。这个网络的第一部分是生成模型,就像之前介绍的自动编码器的解码部分;第二部分是对抗模型,严格来说它是一个判断真假图片的判别器。生成对抗网络最大的创新在此,这也是生成对抗网络与自动编码器最大的区别。简单来说,生成对抗网络就是让两个网络相互竞争,通过生成网络来生成假的数据,对抗网络通过判别器判别真伪,最后希望生成网络生成的数据能够以假乱真骗过判别器。过程如图1所示。

图1 生成对抗网络生成数据过程

下面依次介绍生成模型和对抗模型。

1. 生成模型

首先看看生成模型,前一节自动编码器其实已经给出了一般的生成模型。

在生成对抗网络中,不再是将图片输入编码器得到隐含向量然后生成图片,而是随机初始化一个隐含向量,根据变分自动编码器的特点,初始化一个正态分布的隐含向量,通过类似解码的过程,将它映射到一个更高的维度,最后生成一个与输入数据相似的数据,这就是假的图片。这时自动编码器是通过对比两张图片之间每个像素点的差异计算损失函数的,而生成对抗网络会通过对抗过程来计算出这个损失函数,如图2所示。

图2 生成模型

2. 对抗模型

重点来介绍对抗过程,这个过程是生成对抗网络相对于之前的生成模型如自动编码器等最大的创新。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。

这跟原图片的label 没有关系,不管原图片到底是一个多少类别的图片,它们都统一称为真的图片,输出的label 是1,则表示是真实的;而生成图片的label 是0,则表示是假的。

在训练的时候,先训练判别器,将假的数据和真的数据都输入给判别模型,这个时候优化这个判别模型,希望它能够正确地判断出真的数据和假的数据,这样就能够得到一个比较好的判别器。

然后开始训练生成器,希望它生成的假的数据能够骗过现在这个比较好的判别器。

具体做法就是将判别器的参数固定,通过反向传播优化生成器的参数,希望生成器得到的数据在经过判别器之后得到的结果能尽可能地接近1,这时只需要调整一下损失函数就可以了,之前在优化判别器的时候损失函数是让假的数据尽可能接近0,而现在训练生成器的损失函数是让假的数据尽可能接近1。

这其实就是一个简单的二分类问题,这个问题可以用前面介绍过的很多方法去处理,比如Logistic 回归、多层感知器、卷积神经网络、循环神经网络等。

上面是生成对抗网络的简单解释,可以通过代码更清晰地展示整个过程。

跟自动编码器一样,先使用简单的多层感知器来实现:

class discriminator(nn.Module):
  def __init__(self):
    super(discriminator, self).__init__()
    self.dis = nn.Sequential(
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1),
      nn.Sigmoid()
    )

  def forward(self, x):
    x = self.dis(x)
    return x

上面是判别器的结构,中间使用了斜率设为0.2 的LeakyReLU 激活函数,最后需要使用nn.Sigmoid() 将结果映射到0 s 1 之间概率进行真假的二分类。这里之所以用LeakyReLU 激活函数而不使用ReLU 激活函数,是因为经过实验,LeakyReLU 的表现更好。

class generator(nn.Module):
  def __init__(self, input_size):
    super(generator, self).__init__()
    self.gen = nn.Sequential(
      nn.Linear(input_size, 256),
      nn.ReLU(True),
      nn.Linear(256, 256),
      nn.ReLU(True),
      nn.Linear(256, 784),
      nn.Tanh()
    )

  def forward(self, x):
      x = self.gen(x)
      return x

这就是生成器的结构,跟自动编码器中的解码器是类似的,最后需要使用nn.Tanh(),将数据分布到-1 ~1 之间,这是因为输入的图片会规范化到-1 ~1之间。

接着需要定义损失函数和优化函数:

criterion = nn.BCELoss() # Binary Cross Entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

这里使用二分类的损失函数nn.BCELoss(),使用Adam 优化函数,学习率设置为0.0003。

接着是最为重要的训练过程,这个过程分为两个部分:一个是判别器的训练,一个是生成器的训练。

首先来看看判别器的训练。

img = img.view(num_img, -1)
real_img = Variable(img).cuda()
real_label = Variable(torch.ones(num_img)).cuda()
fake_label = Variable(torch.zeros(num_img)).cuda()

# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out

# bp and optimize
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

开始需要自己创建label,真实的数据是1,生成的假的数据是0,然后将真实的数据输入判别器得到loss,将假的数据输入判别器得到loss,将这两个loss 加起来得到总的loss,然后反向传播去更新参数就能够得到一个优化好的判别器。

接下来是生成模型的训练:

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到随机噪声
fake_img = G(z) # 生成假的图片
output = D(fake_img) # 经过判别器得到结果
g_loss = criterion(output, real_label) # 得到假的图片与真实图片label的loss

# bp and optimize
g_optimizer.zero_grad() # 归0梯度
g_loss.backward() # 反向传播
g_optimizer.step() # 更新生成网络的参数

一个随机隐含向量通过生成网络得到了一个假的数据,然后希望假的数据经过判别模型后尽可能和真实label 接近,通过g_loss = criterion(output, real_label)实现,然后反向传播去优化生成器的参数,在这个过程中,判别器的参数不再发生变化,否则生成器永远无法骗过优化的判别器。

除了使用简单的多层感知器外,也可以在生成模型和对抗模型中使用更加复杂的卷积神经网络,定义十分简单。

class discriminator(nn.Module):
  def __init__(self):
    super(discriminator, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28
      nn.LeakyReLU(0.2, True),
      nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14
      )
    self.conv2 = nn.Sequential(
      nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14
      nn.LeakyReLU(0.2, True),
      nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7
    )
    self.fc = nn.Sequential(
      nn.Linear(64*7*7, 1024),
      nn.LeakyReLU(0.2, True),
      nn.Linear(1024, 1),
      nn.Sigmoid()
    )

  def forward(self, x):
    '''
    x: batch, width, height, channel=1
    '''
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

class generator(nn.Module):
  def __init__(self, input_size, num_feature):
    super(generator, self).__init__()
    self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56
    self.br = nn.Sequential(
      nn.BatchNorm2d(1),
      nn.ReLU(True)
    )
    self.downsample1 = nn.Sequential(
      nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56
      nn.BatchNorm2d(50),
      nn.ReLU(True)
    )
    self.downsample2 = nn.Sequential(
      nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56
      nn.BatchNorm2d(25),
      nn.ReLU(True)
    )
    self.downsample3 = nn.Sequential(
      nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28
      nn.Tanh()
    )

  def forward(self, x):
    x = self.fc(x)
    x = x.view(x.size(0), 1, 56, 56)
    x = self.br(x)
    x = self.downsample1(x)
    x = self.downsample2(x)
    x = self.downsample3(x)
    return x

图3 左边是多层感知器的生成对抗网络,右边是卷积生成对抗网络,右边的图片比左边的图片噪声明显更少。在卷积神经网络里引入了批标准化(Batchnormalization)来稳定训练,同时使用了LeakyReLU 和平均池化来进行训练。生成对抗网络的训练其实是很困难的,因为这是两个对偶网络在相互学习,所以需要增加一些训练技巧才能使训练更加稳定。

图3生成对抗网络对比结果

以上介绍了生成对抗网络的简单原理和训练流程,但是对生成对抗网络而言,它其实并没有真正地学习到它要表示的物体,通过对抗的过程,它只是生成了一张尽可能真的图片,这就意味着没办法决定用哪种噪声能够生成想要的图片,除非把初始分布都试一遍。所以在生成对抗网络提出之后,有很多基于标准生成对抗网络的变式来解决各种各样的问题。

2 Improving GAN

这一节将介绍改善的生成对抗网络,因为生成对抗网络存在很多问题,所以人们研究能否通过改善网络结构或者损害函数来解决这些问题。

1 Wasserstein GAN

Wasserstein GAN 是GAN 的一种变式,我们知道GAN 的训练是非常麻烦的,需要很多训练技巧,而且在不同的数据集上,由于数据的分布会发生变化,也需要重新调整参数,不仅需要小心地平衡生成器和判别器的训练进程,同时生成的样本还缺乏多样性。除此之外最大的问题是没办法衡量这个生成器到底好不好,因为没办法通过判别器的loss 去判断这个事情。虽然DC GAN 依靠对生成器和判别器的结构进行枚举,最终找到了一个比较好的网络设置,但还是没有从根本上解决训练的问题。

WGAN 的出现,彻底解决了下面这些难点:

(1)彻底解决了训练不稳定的问题,不再需要设计参数去平衡判别器和生成器;

(2)基本解决了collapse mode 的问题,确保了生成样本的多样性;

(3)训练中有一个向交叉熵、准确率的数值指标来衡量训练的进程,数值越小代表GAN 训练得越好,同时也就代表着生成的图片质量越高;

(4)不需要精心设计网络结构,用简单的多层感知器就能够取得比较好的效果。

下面先介绍为什么GAN 会有这些缺点,然后解释WGAN是通过什么办法解决这些问题的。

① GAN 的局限性

根据之前介绍的,有下面的式子(1):

从式(1)我们知道原始的GAN 是通过最优判别器下的JS Divergence 来衡量两种分布之间的差异的,而且最优判别器下JS Divergence 越小,就说明两种分布越接近,但是JS Divergence 有一个严重的问题,那就是如果两种分布完全没有重叠部分,或者说重叠部分可忽略,那么JS Divergence 将恒等于常数log2。换句话说,就算两种分布很接近,但是只要它们没有重叠,那么JS Divergence 就是一个常数,这就使得网络没办法通过这个损失函数去学习,因为它没办法知道它是否做得好,这就会导致梯度消失,同时这也使得我们没有办法衡量这两种分布到底有多靠近。

而真实分布与生成的分布没有重叠部分的概率有多大呢?其实是非常大的,直观来讲,真实分布是一个高维分布,而生成的分布来自于一个低维分布,所以其实很有可能生成分布和真实分布之间就没有重叠的部分。除此之外,不可能真正去计算两个分布,只能近似取样,所以也导致了两种分布没有重叠部分。如果判别器训练得太好,那么生成的分布和原来分布基本没有重叠部分,这就导致了梯度消失;如果判别器训练得不好,这样生成器的梯度又不准,就会出现错误的优化方向。如果要使得GAN 能够完美地收敛,那么需要判别器的训练不好也不坏,而这个度是很难把握的,况且这还依赖数据的分布等条件,所以GAN 才这么难训练。

②Wasserstein 距离

既然GAN 存在的问题都是由于JS Divergence 引起的,那么能不能换一种度量方式去衡量两种分布之间的差异,而不使用JS Divergence?答案是肯定的,这就是WGAN中提出的解决办法。

首先介绍一种新的度量方式去度量两种分布之间的差异——Wasserstein 距离,也称为Earth Mover 距离,定义如下:

看上去可能比较复杂,数学解释如下:对于两种分布Pr 和Pg,它们的联合分布是II(Pr,Pg),换句话说II(Pr,Pg) 中每一个联合分布的边缘分布就是Pr 或者Pg。那么对每一个联合分布而言,从里面取样x 和y,并计算x 和y 的距离,然后取遍所有的x 和y 计算一下期望,接着取这些期望里面最小的作为W 距离的定义。

如果上面的解释不够清楚,也可以通俗地解释,因为它还有一个别名叫Earth mover距离,也就是推土机距离,这是什么意思呢?可以把两种分布想象成两堆土,然后想想如何用推土机将一种分布变成另外一种分布的样子,会有很多种移动方案,里面最小消耗的那种方案就是最优的方案,也就是这个距离的定义。

W 距离与JS Divergence 相比有什么好处呢?最大的好处就是不管两种分布是否有重叠,它都是连续变换的而不是突变的,可以用下面这个例子来说明一下,如图4所示。

图4 W 距离例子

通过上面这个演示可以发现,虽然两种分布更接近,但JS Divergence 仍然是log2,W 距离就能够连续而有效地衡量两种分布之间的差异。

③WGAN

W 距离有很好的优越性,把它拿来作为两种分布的度量优化生成器,但是W 距离里面有一个是没办法求解的。作者Martin 在论文附录里面通过定理将这个问题转变成了一个新的问题,有着如下形式:

这里引入了一个新的概念——Lipschitz 连续。如果函数f 满足Lipschitz 连续条件,那么它就满足下面的式子:

我们不希望函数的变化太快,希望函数f 变化能比较平缓。

那么可以将上面的式子改成GAN:

也就是说构建一个神经网络D 作为判别器,希望D 输出的变化比较平缓,在实际计算中限制D 中的参数大小不超过某个范围,这样就使得关于输入的样本,D 的输出变化基本不会超过某个范围,所以就能够基本满足Lipschitz 连续条件。

所以最后构造一个判别器D,满足:

尽可能取到最大,同时D 还要满足Lipschitz 连续条件,得到的L 可以近似为真实分布和生成分布的Wasserstein 距离。原始的GAN 做的是二分类的任务,也就是对于真假图片进行二分类,而WGAN 做的是回归问题,相当于近似拟合Wasserstein 距离。

最后优化生成器的时候希望最小化L,这时候需要满足Lipschitz 连续条件,所以需要做权重的裁剪,由于W 距离的优越性,不再需要担心梯度消失的问题,这样就能够得到WGAN 的整个训练过程。

总结一下,WGAN 与原始GAN 相比,只改了以下四点:

(1)判别器最后一层去掉sigmoid;

(2)生成器和判别器的loss 不取log;

(3)每次更新判别器的参数之后把它们的绝对值裁剪到不超过一个固定常数的数;

(4)不要用基于动量的优化算法(比如momentuem 和Adam),推荐使用RMSProp。

前三点都是从理论分析得到的结果,第(4)点是作者从实验中发现的。对于WGAN,论文作者做了不少实验,得到了几个结论:第一,WGAN 如果使用类似DCGAN 的结构,那么和DCGAN 生成的图片差不多,但是WGAN 的优势就在于不用DCGAN 的结构,也能生成效果比较好的图片,但是把DCGAN 的Batch Normalization 拿掉的话,DCGAN 就不能生成图片了;第二,WGAN 和原始的GAN 都是用多层全连接网络的话,WGAN 生成的图片质量会变得差一些,但是原始的GAN 不仅质量很差,还有多样性不足的问题。

2 Improving WGAN

WGAN 的提出成功地解决了GAN 的很多问题,最后需要满足一阶Lipschitz 连续性条件,所以在训练的时候加了一个限制——权重裁剪。

然而权重的裁剪只是一种简单的做法,不是最好的做法,所以随后有人提出了一些新的办法来解决这个问题。

首先提出一个定理:一个可微函数如果满足1 阶Lipschitz 连续,等价于它的梯度范数处小于1。用式子来表示就是:

有了这个定理,就能够近似地这样去表达W 距离:

不需要在整个分布上都满足Lipschitz 条件,只需要沿着一些直线上的点满足这些,结果就已经很好了,同时在实际中采用的策略也不是取max,因为不希望太小,所以做的是最小化,最后改进的WGAN 就是:

改进后的WGAN 和改进前的WGAN 相比,训练更加稳定,生成的图片效果也更好。

以上内容节选自《深度学习入门之PyTorch》,点此链接可在博文视点官网查看此书。
                 图片描述

  想及时获得更多精彩文章,可在微信中搜索“博文视点”或者扫描下方二维码并关注。
                    图片描述

相关实践学习
基于函数计算一键部署掌上游戏机
本场景介绍如何使用阿里云计算服务命令快速搭建一个掌上游戏机。
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
相关文章
|
1月前
|
机器学习/深度学习 数据采集 人工智能
m基于深度学习网络的手势识别系统matlab仿真,包含GUI界面
m基于深度学习网络的手势识别系统matlab仿真,包含GUI界面
38 0
|
1月前
|
机器学习/深度学习 算法 计算机视觉
基于yolov2深度学习网络的火焰烟雾检测系统matlab仿真
基于yolov2深度学习网络的火焰烟雾检测系统matlab仿真
|
1月前
|
机器学习/深度学习 算法 计算机视觉
m基于深度学习网络的性别识别系统matlab仿真,带GUI界面
m基于深度学习网络的性别识别系统matlab仿真,带GUI界面
29 2
|
2月前
|
机器学习/深度学习 监控 算法
m基于深度学习网络的活体人脸和视频人脸识别系统matlab仿真,带GUI界面
m基于深度学习网络的活体人脸和视频人脸识别系统matlab仿真,带GUI界面
38 0
|
1月前
|
机器学习/深度学习 算法 计算机视觉
基于yolov2深度学习网络的视频手部检测算法matlab仿真
基于yolov2深度学习网络的视频手部检测算法matlab仿真
|
1月前
|
机器学习/深度学习 人工智能 TensorFlow
人工智能与图像识别:基于深度学习的卷积神经网络
人工智能与图像识别:基于深度学习的卷积神经网络
34 0
|
7天前
|
机器学习/深度学习 算法 PyTorch
【动手学深度学习】深入浅出深度学习之线性神经网络
【动手学深度学习】深入浅出深度学习之线性神经网络
52 9
|
1月前
|
机器学习/深度学习 并行计算 算法
m基于深度学习网络的瓜果种类识别系统matlab仿真,带GUI界面
m基于深度学习网络的瓜果种类识别系统matlab仿真,带GUI界面
31 0
|
8天前
|
机器学习/深度学习 自然语言处理 算法
|
1月前
|
机器学习/深度学习 运维 算法

热门文章

最新文章