chap9 现代循环神经网络(1) GRU

关注一个序列

  • 不是每个观察值都同等重要
  • 想只记住相关的观察需要
    • 能遗忘的机制:重置门
    • 能关注的机制:更新门

重置门和更新门

数学描述,对于给定的时间步 tt,假设输入是一个小批量 XtRn×d\mathbf{X}_t \in \mathbb{R}^{n \times d} (样本个数:nn,输入个数:dd),上一个时间步的隐藏状态是 Ht1Rn×h\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}(隐藏单元个数:hh)。然后,重置门 RtRn×h\mathbf{R}_t \in \mathbb{R}^{n \times h} 和更新门 ZtRn×h\mathbf{Z}_t \in \mathbb{R}^{n \times h} 的计算如下:

Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),\begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z), \end{aligned}

其中 Wxr,WxzRd×h\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}Whr,WhzRh×h\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}权重参数br,bzR1×h\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}偏置参数。【都是可学习参数】

请注意,在求和过程中会触发广播机制。我们使用 sigmoid 函数将输入值转换到区间 (0,1)(0, 1)

候选隐藏状态(重置门的应用)

接下来,让我们将重置门 Rt\mathbf{R}_t 与常规隐状态更新机制集成,得到在时间步 tt 的候选隐藏状态 H~tRn×h\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}

H~t=tanh(XtWxh+(RtHt1)Whh+bh),\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),

:eqlabel:gru_tilde_H

其中 WxhRd×h\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}WhhRh×h\mathbf{W}_{hh} \in \mathbb{R}^{h \times h} 是权重参数,bhR1×h\mathbf{b}_h \in \mathbb{R}^{1 \times h} 是偏置项,符号 \odot 是哈达码乘积(按元素乘积)运算符。在这里,我们使用 tanh 非线性激活函数来确保候选隐藏状态中的值保持在区间 (1,1)(-1, 1) 中。

计算的结果是 候选者(candidate),因为我们仍然需要结合更新门的操作。与 rnn 相比 :eqref:gru_tilde_H 中的 Rt\mathbf{R}_tHt1\mathbf{H}_{t-1} 的元素相乘可以减少以往状态的影响。

每当重置门 Rt\mathbf{R}_t 中的项接近 11 时,我们恢复一个如 rnn 中的普通的循环神经网络。

对于重置门 Rt\mathbf{R}_t 中所有接近 00 的项,候选隐藏状态是以 Xt\mathbf{X}_t 作为输入的多层感知机的结果【即,丢弃过往信息】。因此,任何预先存在的隐藏状态都会被 重置 为默认值。

隐藏状态(更新门的应用)

最后,我们需要结合更新门 Zt\mathbf{Z}_t 的效果。

这确定新的隐藏状态 HtRn×h\mathbf{H}_t \in \mathbb{R}^{n \times h} 在多大程度上就是旧的状态 Ht1\mathbf{H}_{t-1} ,以及对新的候选状态 H~t\tilde{\mathbf{H}}_t 的使用量。

更新门 Zt\mathbf{Z}_t 仅需要在 Ht1\mathbf{H}_{t-1}H~t\tilde{\mathbf{H}}_t 之间进行按元素的凸组合就可以实现这个目标。这就得出了门控循环单元的最终更新公式:

Ht=ZtHt1+(1Zt)H~t.\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.

每当更新门 Zt\mathbf{Z}_t 接近 11 时,我们就只保留旧状态。此时,来自 Xt\mathbf{X}_t 的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 tt

相反,当 Zt\mathbf{Z}_t 接近 00 时,新的隐藏状态 Ht\mathbf{H}_t 就会接近候选的隐藏状态 H~t\tilde{\mathbf{H}}_t

这些设计可以帮助我们处理循环神经网络中的梯度消失问题,并更好地捕获时间步距离很长的序列的依赖关系。

例如,如果整个子序列的所有时间步的更新门都接近于 11,则无论序列的长度如何,在序列起始时间步的旧隐藏状态都将很容易保留并传递到序列结束。

总结

门控循环单元具有以下两个显著特征:

  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。