【李沐】十分钟从 PyTorch 转 MXNet

简介: PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。MXNet通过ndarray和 gluon模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法。

PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。例如 Caffe2 最近就并入了 PyTorch。

可能大家不是特别知道的是,MXNet 通过 ndarray 和 gluon 模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法

89e0e8d5de21311740959c69f9ae2fe0258d52af

安装

PyTorch 默认使用 conda 来进行安装,例如

03192dec910f50e049d5fecb3109e8b09f6cdf9b

而 MXNet 更常用的是使用 pip。我们这里使用了 --pre 来安装 nightly 版本

83d9786fd8b0f46bc693765c60c2e0544ec118a7

多维矩阵

对于多维矩阵,PyTorch 沿用了 Torch 的风格称之为 tensor,MXNet 则追随了 NumPy 的称呼 ndarray。下面我们创建一个两维矩阵,其中每个元素初始化成 1。然后每个元素加 1 后打印。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

b472e4f3dada3709d53edf6608ab47f322089ca2

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

28436e80f4887a6ef0ba21b7dedafad23e823410

忽略包名的不一样的话,这里主要的区别是 MXNet 的形状传入参数跟 NumPy 一样需要用括号括起来。

模型训练

下面我们看一个稍微复杂点的例子。这里我们使用一个多层感知机(MLP)来在 MINST 这个数据集上训练一个模型。我们将其分成 4 小块来方便对比。

读取数据

这里我们下载 MNIST 数据集并载入到内存,这样我们之后可以一个一个读取批量。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

a3657b3dbcca68c3b62521edd3f0dd3082a15389

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d50f77b5a334e2bd2072180a29c72c9a74ed18dc

这里的主要区别是 MXNet 使用 transform_first 来表明数据变化是作用在读到的批量的第一个元素,既 MNIST 图片,而不是第二个标号元素。

定义模型

下面我们定义一个只有一个单隐层的 MLP 。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

01818faab6a9ae66f6daae74be169b64c894f344

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

8e95eac5ee43682f1ca882e782de98079be13d9e

我们使用了 Sequential 容器来把层串起来构造神经网络。这里 MXNet 跟 PyTorch 的主要区别是:

8481c8f592b7f349aa84a1de5c171db681516edf 不需要指定输入大小,这个系统会在后面自动推理得到
8481c8f592b7f349aa84a1de5c171db681516edf 全连接和卷积层可以指定激活函数
8481c8f592b7f349aa84a1de5c171db681516edf需要创建一个  name_scope  的域来给每一层附上一个独一无二的名字,这个在之后读写模型时需要
8481c8f592b7f349aa84a1de5c171db681516edf 我们需要显示调用模型初始化函数。


大家知道 Sequential 下只能神经网络只能逐一执行每个层。PyTorch 可以继承 nn.Module 来自定义 forward 如何执行。同样,MXNet 可以继承 nn.Block 来达到类似的效果。

损失函数和优化算法

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

483451f3193b8143e4fe7c180da0a03baff4fc71

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d126effd55a80c7df82aa0e96cb0f5cf7f1c5785

这里我们使用交叉熵函数和最简单随机梯度下降并使用固定学习率 0.1

训练

最后我们实现训练算法,并附上了输出结果。注意到每次我们会使用不同的权重和数据读取顺序,所以每次结果可能不一样。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch

37274d74bd5215f00a2a585afaea92d1eb809284

8481c8f592b7f349aa84a1de5c171db681516edfMXNet

fa0addb6b0a7d824307feb70fcd9eae4ea9e209a

MXNet 跟 PyTorch 的不同主要在下面这几点:

8481c8f592b7f349aa84a1de5c171db681516edf不需要将输入放进  Variable , 但需要将计算放在  mx.autograd.record()  里使得后面可以对其求导
8481c8f592b7f349aa84a1de5c171db681516edf 不需要每次梯度清 0,因为新梯度是写进去,而不是累加
8481c8f592b7f349aa84a1de5c171db681516edf step  的时候 MXNet 需要给定批量大小
8481c8f592b7f349aa84a1de5c171db681516edf需要调用  asscalar()  来将多维数组变成标量。
8481c8f592b7f349aa84a1de5c171db681516edf 这个样例里 MXNet 比 PyTorch 快两倍。当然大家对待这样的比较要谨慎。

下一步

8481c8f592b7f349aa84a1de5c171db681516edf 更详细的 MXNet 的教程:http://zh.gluon.ai/

8481c8f592b7f349aa84a1de5c171db681516edf欢迎给我们留言哪些 PyTorch 的方便之处你希望 MXNet 应该也可以有



原文发布时间为:2018-04-3

本文作者:李沐

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:【李沐】十分钟从 PyTorch 转 MXNet

相关文章
|
4月前
|
机器学习/深度学习 TensorFlow API
tensorflow从头再学1
tensorflow从头再学1
22 1
|
11月前
|
机器学习/深度学习 编译器 TensorFlow
瞎聊深度学习——TensorFlow的基本应用
瞎聊深度学习——TensorFlow的基本应用
|
机器学习/深度学习 数据挖掘 PyTorch
# 【深度学习】:《PyTorch入门到项目实战》第九天:Dropout实现
上一章我们介绍了L2正则化和权重衰退,在深度学习中,还有一个很实用的方法——Dropout,能够减少过拟合问题。之前我们介绍了我们的目的是要训练一种泛化的模型,那么就要求模型的鲁棒性较强。一个还不错的尝试是在训练神经网络时,让模型的结果不那么依赖某个神经元,因此在训练神经网络过程中,我们每次迭代将隐藏层的一些神经元随机丢弃掉,这样就不会使得我们的模型太依赖某一个神经元,从而使得我们的模型在未知的数据集上或许会有更好的泛化能力。下面我们具体来看dropout的原理。
 # 【深度学习】:《PyTorch入门到项目实战》第九天:Dropout实现
|
机器学习/深度学习 人工智能 TensorFlow
毕业设计(基于TensorFlow的深度学习与研究)之完结篇
本文是我的毕业设计基于TensorFlow的深度学习与研究的完结篇,在本篇推文中,我将分为三个部分去写: 第一部分是对我毕业设计系列推文的总体安排; 第二部分是对我毕业设计的总结概括; 第三部分我将引入一个入门级的案例(借助fashion_mnist数据集),一方面是帮助初学者对深度学习和卷积神经网络有一定的了解,另一方面是此案例与我毕设中的一个案例相似度较高(另外,我毕设中涉及的两个案例的源代码我将在答辩之后更新到Github上)。
毕业设计(基于TensorFlow的深度学习与研究)之完结篇
|
机器学习/深度学习 安全 前端开发
2020,PyTorch真的赶上TensorFlow了吗?
几天前,OpenAI 通过官方博客宣布了「全面转向 PyTorch」的消息,计划将自家平台的所有框架统一为 PyPyTorch。这一消息再次引发了社区关于两个框架优劣的讨论。作为后起之秀,PyTorch 真的已经全面赶超 TensorFlow 了吗?为了研究这个问题,数据科学家 Jeff Hale 从在线职位数量、顶会论文中的出现次数、在线搜索结果、开发者使用情况四个方面对两个框架的现状进行了调研。
108 0
2020,PyTorch真的赶上TensorFlow了吗?
|
算法 算法框架/工具 TensorFlow
带你读《TensorFlow机器学习实战指南(原书第2版)》之三:基于TensorFlow的线性回归
本书由资深数据科学家撰写,从实战角度系统讲解TensorFlow基本概念及各种应用实践。真实的应用场景和数据,丰富的代码实例,详尽的操作步骤,带领读者由浅入深系统掌握TensorFlow机器学习算法及其实现。本书第1章和第2章介绍了关于TensorFlow使用的基础知识,后续章节则针对一些典型算法和典型应用场景进行了实现,并配有较详细的程序说明,可读性非常强。读者如果能对其中代码进行复现,则必定会对TensorFlow的使用了如指掌。
|
机器学习/深度学习 人工智能 JavaScript
AlphaGo Zero你也来造一只,PyTorch实现五脏俱全| 附代码
遥想当年,AlphaGo的Master版本,在完胜柯洁九段之后不久,就被后辈AlphaGo Zero (简称狗零) 击溃了。
2106 0
|
机器学习/深度学习 程序员 TensorFlow