老饼讲解:一步一步上手深度学习
好了,我们如果要用pytorch来实现RNN,就要借助pytorch提供的RNN层来进行计算。
本节我们详细说说pytorch中的RNN层到底是怎么回事,其实都不想写这文章的,但自己每次在pytorch中用到RNN时都有些蒙,这张文章算写给我自己的吧。
在说pytorch的RNN层之前,我们先来了解pytorch对时序数据的约定。
我们知道,RNN计算时呢,就是逐时刻计算,也就是拿0时刻的所有样本数据进行计算,得到0时刻的输出,然后用时刻1的所有样本数据计算,得到时刻1的输出,。。。如此类似,如下图所示:

好了,记住它是逐时刻吞吐所有样本数据就行了,为什么要逐时刻呢,因为每个时刻可能依赖于前一时刻的输出。总之,记住它是逐时刻吞吐所有样本数据就行了.
所以呢,在pytorch中,用三维来表示时序数据,其中,第0维代表时刻,第1维代表样本,第2维代表数据。
示图如下:

很容易理解为什么是这样的格式,因为它按第0维就可以逐时刻对数据进行计算了。
好了,pytorch的RNN隐层,就是干着如下的事情:

如图,RNN层按公式,逐时刻计算当前时刻的隐层输出。因此,我们按上图所示的格式,给它一个输入X,它就会给我们计算出每个时刻的隐层输出了。
快来试一下吧!pytorch的RNN层使用代码示例如下:
# 本代码用于展示:pytorch的RNN层的使用
# 本代码来自《老饼讲解-深度学习》www.bbblearn.com
import torch
x = torch.rand(3,2,4,dtype=torch.float32) # 生成输入数据,3个时刻,每时刻2个样本,样本有4个变量
rnn = torch.nn.RNN(4,3) # RNN隐层,输入节点4个,隐层输出节点3个
h_init = torch.zeros(1,2,3) # 隐层的初始值,2是样本个数,3是隐层节点个数
h,_ = rnn(x,h_init) # 对输入进行池化
print("\n输入:",x) # 打印输入
print("\nRNN层的输出:",h) # 打印输出结果运行结果如下:

好了,我们来解说一下代码吧,我们这里的代码是按着2.1中的图来实现的,不妨将代码与图配合着来理解。


首先,第4行是我们的输入,由于有3个时刻,每时刻2个样本,一个样本有4个变量,所以生成的x的shape为:[3,2,4]。
接下来,第5行初始化RNN的结构,由图可知,我们的输入变量有4个,隐层变量(即隐节点)有3个,所以这里初始化为torch.nn.RNN(4,3)。
好了,第6行是隐层的初始值,隐层的初始值其实就是与单时刻隐层的输出一样。

理论上来说,初始H应该是一个单时刻H,也就是一个二维矩阵。但我们输入的初始H是一个三维矩阵,即在单时刻H上外套一维。即理论上h_init = torch.zeros(2,3),但代码中用的是h_init = torch.zeros(1,2,3)。为什么要增加一维呢?这是为了兼容DRNN(深度RNN)模型,所以我们这里先不管,给它套多一维就好了。
好了,接下来第7行利用RNN层计算输出,那么它就会给我们返回3个时刻所有样本的隐层输出了,对着图来看就行了。现在问题来了,我们代码中用的是h,_ = rnn(x,h_init) ,这个h我们可以理解,就是隐层的输出嘛,_是什么呢?_其实也是一个变量,当函数有这个输出,而我们不想用它时,就将它用_来命名。好了,为什么我们不想用第二个输出呢,因为它也是兼容DRNN的。
最后,代码第8、9行打印出了输入x,以及RNN层的输出h,我们可以对着结果看一下是不是跟理解中一样。

好了,我们补充说说上面为什么隐层初始值需要多套一维,以及RNN隐层的第二个输出是什么。其实它们都是为DRNN(深度RNN)服务的,简单看下DRNN:

如图,DRNN就是有多个隐层的RNN。那么,初始化隐层时,就需要输入各个隐层的初始值了,同时,我们可能需要DRNN每个隐层最后一时刻的输出,所以,RNN层的第二个输出就是它了。
好了,直接简单上个代码吧
# 本代码用于展示:pytorch的多隐层RNN
# 本代码来自《老饼讲解-深度学习》www.bbblearn.com
import torch
x = torch.rand(3,2,4,dtype=torch.float32) # 生成输入数据,3个时刻,每时刻2个样本,样本有4个变量
rnn = torch.nn.RNN(4,3,2) # RNN隐层,输入节点4个,隐层神经元3个,有2个隐层
h_init = torch.zeros(2,2,3) # 2个隐层的初始值,2是样本个数,3是隐层节点个数
h,ht = rnn(x,h_init) # 对输入进行池化
print("\n输入:",x) # 打印输入
print("\nRNN层的输出:",h) # 打印输出结果
print("\nRNN层的输出:",ht) # 打印输出结果运行结果如下:

好了,具体就不解说了,真的好累,等要用到DRNN时再仔细看一看代码的注释和运行结果就清楚了。
这里主要梳理了pytorch的RNN层是怎么使用的,包括它的输入数据、输出数据的意思是什么。
建议使用RNN层时,对输入输出不明白再来细细看,毕竟,真的很烧脑、很啰嗦~
评论