使用PyTorch从零开始构建Elman循环神经网络

简介: 循环神经网络是如何工作的?如何构建一个Elman循环神经网络?在这里,教你手把手创建一个Elman循环神经网络进行简单的序列预测。

本文以最简单的RNNs模型为例:Elman循环神经网络,讲述循环神经网络的工作原理即便是你没有太多循环神经网络(RNNs)的基础知识,也可以很容易的理解。为了让你更好的理解RNNs我们使用Pytorch张量包和autograd从头开始构建Elman循环神经网络。该文中完整代码在Github上是可实现的。

在这里,假设你对前馈神经网络略有了解。Pytorchautograd更为详细的内容请查看我的其他教程

68363d99d874dfb5a2f53084e522c73a9f5f34cd

Elman循环神经网络

Jeff Elman首次提出了Elman循环神经网络,并发表在论文《Finding structure in time》中:它只是一个三层前馈神经网络,输入层由一个输入神经元x1一组上下文神经元单元{c1 ... cn}组成隐藏层前一时间步的神经元作为上下文神经元的输入在隐藏层中每个神经元有一个上下文神经元。由于前一时间步的状态作为输入的一部分,因此我们可以说Elman循环神经网络拥有一定的内存——上下文神经元代表一个内存。

预测正弦波

现在,我们来训练RNNs学习正弦函数。在训练过程中,一次为模型提供一个数据,这就是为什么我们只需要一个输入神经元x1,并且我们希望在下一时间步预测该值。输入序列x20个数据组成,并且目标序列与输入序列相同。

5a336e128286628cf86977c17783f34d70d06e35 

模型实现

首先导入包。

e5048e42574aeaea8f76d0af60084ea0f24d3c9c 

接下来,设置模型的超参数。设置输入层的大小为76个上下文神经元和1个输入神经元),seq_length用来定义输入和目标序列的长度。

c8be968e1760e5b306d6e3a5bec48b5d46c12026 

生成训练数据:x是输入序列,y是目标序列。

6fa08fc0e07c757d55b30fd33e064bb0347ae3b7 

创建两个权重矩阵。大小为(input_sizehidden_size)的矩阵w1用于隐藏连接的输入,大小为(hidden_sizeoutput_size)的矩阵w2用于隐藏连接的输出。 用零均值的正态分布对权重矩阵进行初始化。

d8a4877fbd8b331ef7755cc0a7b4b1cec27df859 

定义forward方法,其参数为input向量、context_state向量和两个权重矩阵,连接inputcontext_state创建xh向量。对xh向量和权重矩阵w1执行点积运算,然后用tanh函数作为非线性函数,在RNNstanhsigmoid效果要好。 然后对新的context_state和权重矩阵w2再次执行点积运算。 我们想要预测连续值,因此这个阶段不使用任何非线性。

请注意,context_state向量将在下一时间步填充上下文神经元。 这就是为什么我们要返回context_state向量和out

4ca6517a08cbd409471972aab15631a2447a6b98 

训练

训练循环的结构如下:

1.外循环遍历每个epochepoch被定义为所有的训练数据全部通过训练网络一次。在每个epoch开始时,将context_state向量初始化为0

2.内部循环遍历序列中的每个元素。执行forward方法进行正向传递,该方法返回predcontext_state,将用于下一个时间步。然后计算均方误差(MSE)用于预测连续值。执行backward()方法计算梯度,然后更新权重w1w2。每次迭代中调用zero_()方法清除梯度,否则梯度将会累计起来。最后将context_state向量包装放到新变量中,以将其与历史值分离开来。

4085d3f491eab971b4f5612b58d3e9d11b68b7b4 

训练期间产生的输出显示了每个epoch的损失是如何减少的,这是一个好的衡量方式。损失的逐渐减少则意味着我们的模型正在学习。

23b7ff14535e2eb453038bfc222503c1cec5cc26 

预测

一旦模型训练完毕,我们就可以进行预测。在序列的每一步我们只为模型提供一个数据,并要求模型在下一个步预测一个值。

f25e66a8ab5e47ae383972b9cc3a80e64420e433 

预测结果如下图所示:黄色圆点表示预测值,蓝色圆点表示实际值,二者基本吻合,因此模型的预测效果非常好。

a16cb76e976c4b32d1db8e55780e5a18816f67dd

结论

在这里,我们使用了Pytorch从零开始构建一个基本的RNNs模型,并且学习了如何将RNNs应用于简单的序列预测问题。


数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧! 

以上为译文。

本文由北邮@爱可可-爱生活 老师推荐,阿里云云栖社区组织翻译。

文章原标题《Introduction to Recurrent Neural Networks in Pytorch》,译者:Mags,审校:袁虎。

文章为简译,更为详细的内容,请查看原文 

 

相关文章
|
17天前
|
安全 网络安全 数据安全/隐私保护
网络堡垒的构建者:洞悉网络安全与信息安全的深层策略
【4月更文挑战第9天】在数字化时代,数据成为了新的价值核心。然而,随之而来的是日益复杂的网络安全威胁。从漏洞利用到信息泄露,从服务中断到身份盗用,攻击手段不断演变。本文深入剖析了网络安全的关键组成部分:识别和防范安全漏洞、加密技术的应用以及提升个体和企业的安全意识。通过探讨这些领域的最佳实践和最新动态,旨在为读者提供一套全面的策略工具箱,以强化他们在数字世界的防御能力。
|
1月前
|
机器学习/深度学习 算法 PyTorch
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
232 1
|
1天前
|
云安全 安全 网络安全
云端防御战线:构建云计算环境下的网络安全体系
【4月更文挑战第25天】 随着企业数字化转型的加速,云计算以其灵活性、成本效益和可扩展性成为众多组织的首选技术平台。然而,云服务的广泛采用也带来了前所未有的安全挑战,特别是在数据保护、隐私合规以及网络攻击防护等方面。本文将深入探讨云计算环境中的网络安全策略,从云服务模型出发,分析不同服务层次的安全责任划分,并针对网络威胁提出综合性的防御措施。此外,文中还将讨论信息加密、身份验证、入侵检测等关键技术在维护云安全中的作用,以期为读者提供一套全面的云安全解决方案框架。
|
2天前
|
移动开发 Java Android开发
构建高效Android应用:采用Kotlin协程优化网络请求
【4月更文挑战第24天】 在移动开发领域,尤其是对于Android平台而言,网络请求是一个不可或缺的功能。然而,随着用户对应用响应速度和稳定性要求的不断提高,传统的异步处理方式如回调地狱和RxJava已逐渐显示出局限性。本文将探讨如何利用Kotlin协程来简化异步代码,提升网络请求的效率和可读性。我们将深入分析协程的原理,并通过一个实际案例展示如何在Android应用中集成和优化网络请求。
|
7天前
|
存储 安全 网络安全
构建坚固的防线:云计算环境下的网络安全策略
【4月更文挑战第19天】 随着企业纷纷迁移至云平台,云计算已成为现代信息技术架构的核心。然而,数据存储与处理的云端化也带来了前所未有的安全挑战。本文深入探讨了在复杂多变的云环境中,如何实施有效的网络安全措施,确保信息安全和业务连续性。通过分析云服务模型、网络威胁以及加密技术,提出了一系列切实可行的安全策略,旨在帮助组织构建一个既灵活又强大的防御体系。
16 1
|
7天前
|
监控 安全 算法
数字堡垒的构建者:网络安全与信息保护的现代策略
【4月更文挑战第19天】在信息化快速发展的今天,网络安全和信息安全已成为维护社会稳定、保障个人隐私和企业商业秘密的关键。本文将深入探讨网络安全漏洞的成因、加密技术的进展以及提升安全意识的重要性,旨在为读者提供一套综合性的网络防护策略,以应对日益猖獗的网络威胁。
7 1
|
9天前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch小技巧:使用Hook可视化网络层激活(各层输出)
这篇文章将演示如何可视化PyTorch激活层。可视化激活,即模型内各层的输出,对于理解深度神经网络如何处理视觉信息至关重要,这有助于诊断模型行为并激发改进。
11 1
|
10天前
|
机器学习/深度学习 资源调度 数据可视化
使用Python和Keras进行主成分分析、神经网络构建图像重建
使用Python和Keras进行主成分分析、神经网络构建图像重建
13 1
|
11天前
|
数据采集 API 数据安全/隐私保护
畅游网络:构建C++网络爬虫的指南
本文介绍如何使用C++和cpprestsdk库构建高效网络爬虫,以抓取知乎热点信息。通过亿牛云爬虫代理服务解决IP限制问题,利用多线程提升数据采集速度。示例代码展示如何配置代理、发送HTTP请求及处理响应,实现多线程抓取。注意替换有效代理服务器参数,并处理异常。
畅游网络:构建C++网络爬虫的指南
|
19天前
|
机器学习/深度学习 人工智能 运维
构建未来:AI驱动的自适应网络安全防御系统
【4月更文挑战第7天】 在数字时代的浪潮中,网络安全已成为维系信息完整性、保障用户隐私和确保商业连续性的关键。传统的安全防御策略,受限于其静态性质和对新型威胁的响应迟缓,已难以满足日益增长的安全需求。本文将探讨如何利用人工智能(AI)技术打造一个自适应的网络安全防御系统,该系统能够实时分析网络流量,自动识别并响应未知威胁,从而提供更为强大和灵活的保护机制。通过深入剖析AI算法的核心原理及其在网络安全中的应用,我们将展望一个由AI赋能的、更加智能和安全的网络环境。
28 0