GANs很难?这篇文章教你50行代码搞定(PyTorch)

简介:
本文来自AI新媒体量子位(QbitAI)

640?wx_fmt=png&wxfrom=5&wx_lazy=1


2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,正式把生成对抗网络(GANs)介绍给全世界。通过把计算图和博弈论创新性的结合起来,GANs有能力让两个互相对抗的模型通过反向传播共同训练。

模型中有两个相互对抗的角色,我们分别称为GD,简单解释如下:G是一个生成器,它试图通过学习真实数据集R,来创建逼真的假数据;D鉴别器,从R和G处获得数据并标记差异。

Goodfellow有个很好的比喻:G是一个造假团队,试图造出跟真画一样的赝品;D是鉴定专家,试图找出真画和赝品的差异。当然在GANs的设定里,G是一群永远见不到真画的造假团队,他们能够获得的反馈只有D的鉴定意见。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

在理想情况下,D和G都会随着时间的推移变得更好,直到G变成一个造假大师,最终让D无法区分出真画和赝品。实际上,Goodfellow已经表明G能够对原始数据集进行无监督学习,并且找到这些数据的低维表达方式。


这么厉害的技术,代码怎么也得一大堆吧?

并不是。使用刚刚发布的PyTorch,实际上可以只用不到50行代码,就能创建一个GAN。我们需要考虑的组件只有下面五个:

 R:原始的真实数据集

 I:作为熵源输入生成器的随机噪声

 G:尝试复制/模仿原始数据集的生成器

 D:尝试分辨G输出的鉴别器

 一个训练循环:教G造假,再教D来鉴别……

1)R: 我们将从最简单的R,一个钟形曲线开始。这个函数以平均值和标准偏差为参数,然后返回一个函数。在我们的示例代码中,使用了平均值4.0和标准差1.25。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

2)I: 输入生成器的噪声也是随机的,但是为了增加点难度,我们使用了一个均匀分布,而不是正态分布。这意味着模型G不能简单地通过移动/缩放复制R,而必须以非线性的方式重塑数据。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

3)G: 生成器是一个标准的前馈图,包含两个隐藏层,三个线性映射。在这里,我们使用了ELU(指数线性单位)。G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

4)D: 鉴别器与生成器G的代码非常相似,都是有两个隐藏层和三个线性映射的前馈图。它将从R或G获取样本,并输出介于0和1之间的单个标量,0和1分别表示“假”和“真”。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

5)训练循环 最后,训练循环在两种模式之间交替:首先,用带有准确标签的真实数据和假数据来训练D;然后,训练G来愚弄D。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

即使你从没用过PyTorch,也大致能看出发生了什么。在上图标为绿色的第一部分,我们将不同类型的数据输入D,并对D的猜测结果和实际的标签进行评判。这一步是“正向”的,然后我们用“反向”来计算梯度,并用它来更新d_optimizer step()调用的D参数。

上面,我们用到了G,但没有训练它。

在标为红色的下半部分中,我们对G做了同样的事情,注意:我们还会通过D来运行G的输出,相当于给了造假者一个侦探练习。但是在这一步中,我们不会对D进行优化或更改,因为我们不希望D学到错误的标签。因此,我们只调用g_optimizer.step()。

就这些啦,还有一些其他的样本代码,但是针对GAN的只有这五个组件。


对D和G进行几千轮训练之后,我们能得到什么?鉴别器D优化得很快,而G一开始优化得比较慢,不过,一旦到达了特定水平,G就开始迅速成长。

两万轮训练过后,G的输出的平均值超过4.0,但随后回到一个相当稳定,正确的范围(如左图)。同样,标准偏差最初在错误的方向下降,但随后上升到所要求的1.25范围(右图),与R相当。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

所以,基本的统计最终与R相当,那么高阶矩如何呢?分布的形状是否正确?毕竟,你当然可以有一个平均值为4.0、标准差为1.25的均匀分布,但这不会真正与R相匹配。让我们看看G形成的最终分布。

640?wx_fmt=png&wxfrom=5&wx_lazy=1

还不错。左尾比右边稍微长了一点,但是我们可以说,它的偏斜和峰态符合原始的高斯函数。

G几乎完美还原了R的原始分布,而D独自在角落徘徊,无法分清真伪。这正是我们想要的结果。用不到50行的代码,就能实现。

本文作者:Dev Nag
原文发布时间:2017-02-13
相关文章
|
10月前
|
索引
【Pytorch--代码技巧】各种论文代码常见技巧
博主在阅读论文原代码的时候常常看见一些没有见过的代码技巧,特此将这些内容进行汇总
118 0
|
11月前
|
存储 算法 计算机视觉
【项目实践】从零开始学习Deep SORT+YOLO V3进行多目标跟踪(附注释项目代码)(二)
【项目实践】从零开始学习Deep SORT+YOLO V3进行多目标跟踪(附注释项目代码)(二)
139 0
|
11月前
|
机器学习/深度学习 算法 决策智能
【项目实践】从零开始学习Deep SORT+YOLO V3进行多目标跟踪(附注释项目代码)(一)
【项目实践】从零开始学习Deep SORT+YOLO V3进行多目标跟踪(附注释项目代码)(一)
216 0
|
机器学习/深度学习 PyTorch 算法框架/工具
从零开始学Pytorch(十二)之凸优化
从零开始学Pytorch(十二)之凸优化
从零开始学Pytorch(十二)之凸优化
|
算法 计算机视觉
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(三)
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(三)
206 0
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(三)
|
机器学习/深度学习 算法 图计算
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(一)
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(一)
107 0
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(一)
|
存储 编解码 算法
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(二)
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(二)
227 0
熟练掌握CV中最基础的概念:图像特征,看这篇万字的长文就够了(二)
|
搜索推荐 大数据 PyTorch
推荐模型复现(一):熟悉Torch-RecHub框架与使用
Torch-RecHub是一个轻量级的pytorch推荐模型框架
568 0
推荐模型复现(一):熟悉Torch-RecHub框架与使用
|
机器学习/深度学习 PyTorch TensorFlow
深度学习基础之三分钟轻松搞明白tensor到底是个啥
再不入坑就晚了,深度神经网络概念大整理,最简单的神经网络是什么样子?
427 0
深度学习基础之三分钟轻松搞明白tensor到底是个啥
|
存储 数据可视化 PyTorch
学懂 ONNX,PyTorch 模型部署再也不怕!
在把 PyTorch 模型转换成 ONNX 模型时,我们往往只需要轻松地调用一句 torch.onnx.export 就行了。这个函数的接口看上去简单,但它在使用上还有着诸多的“潜规则”。在这篇教程中,我们会详细介绍 PyTorch 模型转 ONNX 模型的原理及注意事项。除此之外,我们还会介绍 PyTorch 与 ONNX 的算子对应关系,以教会大家如何处理 PyTorch 模型转换时可能会遇到的算子支持问题。
3690 0
学懂 ONNX,PyTorch 模型部署再也不怕!