自注意力机制有自己的一套选择 query 和 value 的方法。将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。
具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention)
自注意力
-
给定序列 x1,…,xn,其中任意 xi∈Rd (1≤i≤n)。
-
自注意力池化层将 xi 作为 key,value,query 来对序列抽取特征得到 y1,…,yn
yi=f(xi,(x1,x1),…,(xn,xn))∈Rd
跟 CNN RNN 对比
🏷fig_cnn-rnn-self-attention
|
CNN |
RNN |
自注意力 |
计算复杂度 |
O(knd^2) |
O(nd^2) |
O(n^2d) |
并行度 |
O(n) |
O(1) |
O(n) |
最长路径 |
O(n/k) |
O(n) |
O(1) |
考虑一个卷积核大小为 k 的卷积层。我们将在后面的章节中提供关于使用卷积神经网络处理序列的更多详细信息。目前,我们只需要知道,由于序列长度是 n,输入和输出的通道数量都是 d,所以卷积层的计算复杂度为 O(knd2)。如 :numref:fig_cnn-rnn-self-attention
所示,卷积神经网络是分层的,因此为有 O(1) 个顺序操作,最大路径长度为 O(n/k)。例如,x1 和 x5 处于 :numref:fig_cnn-rnn-self-attention
中卷积核大小为 3 的双层卷积神经网络的感受野内。
当更新循环神经网络的隐藏状态时,d×d 权重矩阵和 d 维隐藏状态的乘法计算复杂度为 O(d2)。由于序列长度为 n,因此循环神经网络层的计算复杂度为 O(nd2)。根据 :numref:fig_cnn-rnn-self-attention
,有 O(n) 个顺序操作无法并行化,最大路径长度也是 O(n)。
在自注意力中,查询、键和值都是 n×d 矩阵。考虑 :eqref:eq_softmax_QK_V
中缩放的”点-积“注意力,其中 n×d 矩阵乘以 d×n 矩阵,然后输出的 n×n 矩阵乘以 n×d 矩阵。因此,自注意力具有 O(n2d) 计算复杂性。正如我们在 :numref:fig_cnn-rnn-self-attention
中看到的那样,每个词元都通过自注意力直接连接到任何其他词元。因此,有 O(1) 个顺序操作可以并行计算,最大路径长度也是 O(1)。
总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。
位置编码
在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。
为了使用序列的顺序信息,我们通过在输入表示中添加 位置编码(positional encoding)来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。
接下来,我们描述的是基于正弦函数和余弦函数的固定位置编码 。
假设输入表示 X∈Rn×d 包含一个序列中 n 个词元的 d 维嵌入表示。位置编码使用相同形状的位置嵌入矩阵 P∈Rn×d 输出 X+P,矩阵第 i 行、第2j列和2j 列上的元素为:
pi,2jpi,2j+1=sin(100002j/di),=cos(100002j/di).
:eqlabel:eq_positional-encoding-def
上述公式是对绝对位置信息做了投影变换从而得到了相对位置信息。
绝对位置信息:对位置的二进制编码,例如,0:000,1:001,2:010…
这是因为对于任何确定的位置偏移 δ,位置 i+δ 处的位置编码可以线性投影位置 i 处的位置编码来表示。
这种投影的数学解释是,令 ωj=1/100002j/d,对于任何确定的位置偏移 δ,:eqref:eq_positional-encoding-def
中的任何一对 (pi,2j,pi,2j+1) 都可以线性投影到 (pi+δ,2j,pi+δ,2j+1):
===[cos(δωj)−sin(δωj)sin(δωj)cos(δωj)][pi,2jpi,2j+1][cos(δωj)sin(iωj)+sin(δωj)cos(iωj)−sin(δωj)sin(iωj)+cos(δωj)cos(iωj)][sin((i+δ)ωj)cos((i+δ)ωj)][pi+δ,2jpi+δ,2j+1],
2×2 投影矩阵不依赖于任何位置的索引 i。
总结
- 在自注意力中,查询、键和值都来自同一组输入。
- 自注意力完全并行、最长路径为 1、但对长序列的计算复杂度较高
- 为了使用使得自注意力能够记住序列的顺序信息,我们可以通过在输入表示中添加位置编码来注入绝对的或相对的位置信息。