chap9 现代循环神经网络(2) LSTM

长期以来,隐变量模型存在着长期信息保存和短期输入跳跃的问题。解决这一问题的最早方法之一是长短期存储器(long short-term memory, LSTM) 。它有许多与门控循环单元一样的属性。

有趣的是,长短期记忆网络(LSTM)的设计比门控循环单元稍微复杂一些,却比门控循环单元(GRU)早诞生了近20年。

门控记忆单元

长短期记忆网络引入了 存储单元(memory cell),或简称为 单元(cell)。有些文献认为存储单元是隐藏状态的一种特殊类型,它们与隐藏状态具有相同的形状,其设计目的是用于记录附加的信息。

为了控制存储单元,我们需要许多门。

  • 输入门(ItI_t):决定是否忽略掉输入数据
  • 遗忘门(FtF_t):将值朝 0 减少
  • 输出门(OtO_t):决定是否使用隐变量的值

数学描述,假设有 hh 个隐藏单元,批量大小为 nn,输入数为 dd。因此,输入为 XtRn×d\mathbf{X}_t \in \mathbb{R}^{n \times d},前一时间步的隐藏状态为 Ht1Rn×h\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}。相应地,时间步 tt 的门被定义如下:输入门是 ItRn×h\mathbf{I}_t \in \mathbb{R}^{n \times h},遗忘门是 FtRn×h\mathbf{F}_t \in \mathbb{R}^{n \times h},输出门是 OtRn×h\mathbf{O}_t \in \mathbb{R}^{n \times h}。它们的计算方法如下:

It=σ(XtWxi+Ht1Whi+bi),Ft=σ(XtWxf+Ht1Whf+bf),Ot=σ(XtWxo+Ht1Who+bo),\begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned}

其中 Wxi,Wxf,WxoRd×h\mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h}Whi,Whf,WhoRh×h\mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h}权重参数bi,bf,boR1×h\mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h}偏置参数

候选记忆单元

接下来,设计记忆单元。由于还没有指定各种门的操作,所以先介绍 候选记忆单元(candidate memory cell)C~tRn×h\tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}

它的计算与上面描述的三个门的计算类似,但是使用 tanh\tanh 函数作为激活函数,函数的值范围为 (1,1)(-1, 1)。下面导出在时间步 tt 处的方程:

C~t=tanh(XtWxc+Ht1Whc+bc),\tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),

其中 WxcRd×h\mathbf{W}_{xc} \in \mathbb{R}^{d \times h}WhcRh×h\mathbf{W}_{hc} \in \mathbb{R}^{h \times h} 是权重参数,bcR1×h\mathbf{b}_c \in \mathbb{R}^{1 \times h} 是偏置参数。

候选记忆单元的图示如 :

记忆单元(辅助)

在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。

类似地,在长短期记忆网络中,也有两个门用于这样的目的:

输入门 It\mathbf{I}_t 控制采用多少来自 C~t\tilde{\mathbf{C}}_t 的新数据,

遗忘门 Ft\mathbf{F}_t 控制保留了多少旧记忆单元 Ct1Rn×h\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h} 的内容。使用与前面相同的按元素做乘法的技巧,得出以下更新公式:

Ct=FtCt1+ItC~t.\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.

如果遗忘门始终为 11 且输入门始终为 00,则过去的记忆单元 Ct1\mathbf{C}_{t-1} 将随时间被保存并传递到当前时间步。引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。

这样就得到了流程图,如:

隐藏状态

最后,我们需要定义如何计算隐藏状态 HtRn×h\mathbf{H}_t \in \mathbb{R}^{n \times h}。这就是输出门发挥作用的地方。

在长短期记忆网络中,它仅仅是记忆单元的 tanh\tanh 的门控版本。这就确保了 Ht\mathbf{H}_t 的值始终在区间 (1,1)(-1, 1) 内。

Ht=Ottanh(Ct).\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).

只要输出门接近 11,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近 00,我们只保留存储单元内的所有信息,并且没有进一步的过程需要执行。

总结

It=σ(XtWxi+Ht1Whi+bi),Ft=σ(XtWxf+Ht1Whf+bf),Ot=σ(XtWxo+Ht1Who+bo),C~t=tanh(XtWxc+Ht1Whc+bc),Ct=FtCt1+ItC~t,Ht=Ottanh(Ct).\begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o),\\ \tilde{\mathbf{C}}_t &= \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),\\ \mathbf{C}_t &= \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t,\\ \mathbf{H}_t &= \mathbf{O}_t \odot \tanh(\mathbf{C}_t). \end{aligned}