7.3 LSTM网络#
在前面两节内容中,我们详细介绍了RNN模型的原理以及在PyTorch框架中的使用方法。虽然理论上RNN模型在处理序列数据方面具有着很好的效果,但在处理长序列数据时RNN模型可能会出现梯度消失或爆炸的情况,进而导致模型无法学习到长期依赖的关系。在这样的背景下,基于RNN模型改进的长短期记忆网络(Long Short-Term Memory, LSTM)[1]便应运而生了。
7.3.1 LSTM动机#
根据7.1节内容可知,RNN模型出现的动机便是为了解决时序数据的特征编码问题,但在实际情况中由于时序数据的时间步长较多因而又导致了新问题的出现。例如模型在对一个时序长度为50的序列进行特征提取时,第50个时刻的隐含向量$h_{50}$便会依赖之前所有时刻的信息。在这样的情景下,一方面由7.1.6节内容可知模型在训练过程中极易出现梯度消失或爆炸的情况;另一方面由于历史时刻和当前时刻的状态信息得不到有效筛选,使得模型难以学到真正对下游任务有用的信息,而这也被称之为长期依赖(Long-Term Dependencies)问题。
基于这样的动机, 霍赫赖特等人在1997年提出了基于门控单元的长短期记忆网络。LSTM模型的设计思想在于通过引入门控机制(Gating)和记忆状态(Memory State)来解决上述两个问题。门控机制可以控制信息的流动决定哪些信息应该被保留,哪些应该被遗忘,以及哪些信息应该输入给下一个记忆单元以便模型可以更好地捕捉序列的长期依赖关系。同时,记忆状态则是类似于残差模块中的连接,使得经过筛选后的记忆状态能够直接输入到下一个记忆单元,从而缓解了RNN中的梯度消失或爆炸问题。
7.3.2 LSTM结构#
同RNN模型类似,LSTM模型也是一个在时间维度进行展开的循环结构。在原始的RNN模型中,每个重复的记忆单元内只包含一个简单的网络层,而对于LSTM模型来说每个记忆单元包含有4个不同的网络层,并通过不同的方式进行相互作用以此得到当前时候的输出结果。如图7-6所示便是LSTM模型的网络结构示意图。

在图7-6中,上方贯穿每个记忆单元的便是LSTM中引入的记忆状态$C_t$,同时$f_t$ 、$i_t$、$o_t$和$\tilde{C_t}$ 分别表示遗忘门(Forget Gate)、输入门(Input Gate)、输出门(Output Gate)和输入层对应的输出结果,$h_t$表示当前时候的输出结果。从图7-6可以看出,LSTM中的每个记忆单元都会通过3个门控结构($f_t$、$i_t$和$o_t$)来完成对信息流动的控制与筛选。每个门控结构都是根据同样的输入经Sigmoid函数作用后得到一系列取值在区间$[0,1]$ 里的值,其中0表示将流经的所有信息进行完全抑制,而1则表示对流经的信息不做任何处理。
1. 遗忘门
遗忘门主要用于对历史记忆状态$C_{t-1}$进行筛选,选择性的遗忘或者保留部分历史信息,使得真正有用的信息能够贯穿LSTM 网络。

如图7-7所示,在LSTM中第1步需要做的就是确定应该丢弃哪些信息,$t-1$时刻的输出$h_{t-1}$和当前时刻的输入$x_t$经过遗忘门之后,后续再通过按位乘作用于$C_{t-1}$便完成了对历史记忆状态的筛选,其具体计算过程为
$$ f_t=\sigma([h_{t-1},x_t]W_f+b_f)\tag{7-9} $$
其中$\sigma$为Sigmoid函数,$[a,b]$ 表示将$a,b$两个向量进行堆叠组合,$x_t$的形状为[batch_size,input_size],$h_{t-1}$的形状为[batch_size,hidden_size],$[h_{t-1},x_t]$的形状为[batch_size,input_size+hidden_size],$W_f$的形状为[input_size+hidden_size,hidden_size],$b_f$的形状为[hidden_size],$f_t$的形状为[batch_size,hidden_size]。
2.输入门
输入门主要用于对输入信息$\tilde{C_t}$进行筛选,仅让其中部分信息流入到当前时刻,然后再与历史记忆状态融合形成新的记忆状态$C_t$。

如图7-8所示,在LSTM中第2步需要做的便是通过输入门对当前时刻的输入信息进行筛选,其具体计算过程为
$$ \begin{aligned} i_t&=\sigma([h_{t-1},x_t]W_i+b_i)\\[2ex] \tilde{C_t}&=\tanh([h_{t-1},x_t]W_c+b_c)\\[2ex] \end{aligned}\tag{7-10} $$
其中$W_i$和$W_c$的形状均为[input_size+hidden_size,hidden_size],$b_i$和$b_c$的形状均为[hidden_size],$i_t$和$\tilde{C}_t$的形状均为[batch_size,hidden_size]。
进一步对历史记忆状态$C_{t-1}$进行更新,计算过程为
$$ C_t=f_t\odot C_{t-1}\oplus i_t\odot\tilde{C_t}\tag{7-11} $$
其中$\odot$表示按位乘,$\oplus$表示按位加,$C_{t-1}$的形状为[batch_size,hidden_size]。
3. 输出门