chap9 现代循环神经网络(7) 束搜索

sec_seq2seq 中,我们逐个预测输出序列,直到预测序列中出现特定的序列结束词元“<eos>”。在
本节中,我们将首先介绍 贪心搜索(greedy search)策略,并探讨其存在的问题,然后对比其他替代策略:穷举搜索(exhaustive search)和束搜索(beam search)。

在正式介绍贪心搜索之前,让我们定义搜索问题。

\在任意时间步 tt',解码器输出 yty_{t'} 的概率取决于时间步 tt' 之前的输出子序列 y1,,yt1y_1, \ldots, y_{t'-1} 和对输入序列的信息进行编码得到的上下文变量 c\mathbf{c}

为了量化计算成本,用 Y\mathcal{Y} 表示输出词汇表,其中包含“<eos>”,所以这个词汇集合的基数 Y\left|\mathcal{Y}\right| 就是词汇表的大小。我们还将输出序列的最大词元数指定为 TT'。因此,我们的目标是从所有 O(YT)\mathcal{O}(\left|\mathcal{Y}\right|^{T'}) 个可能的输出序列中寻找理想的输出。当然,对于所有输出序列,这些序列中包含的“<eos>”及其之后的部分将在实际输出中丢弃。

贪心搜索

对于输出序列的任何时间步 tt',我们都将基于贪心搜索从 Y\mathcal{Y} 中找到具有最高条件概率的词元,即:

yt=*argmaxyYP(yy1,,yt1,c)y_{t'} = \operatorname*{argmax}_{y \in \mathcal{Y}} P(y \mid y_1, \ldots, y_{t'-1}, \mathbf{c})

一旦输出序列包含了“<eos>”或者达到其最大长度 TT',则输出完成。

问题

现实中,最优序列(optimal sequence)应该是最大化 t=1TP(yty1,,yt1,c)\prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c}) 值的输出序列,这是基于输入序列生成输出序列的条件概率。不幸的是,无法保证通过贪心搜索得到的是最优序列。

穷举搜索

如果目标是获得最优序列,我们可以考虑使用 穷举搜索(exhaustive search):穷举地列举所有可能的输出序列及其条件概率,然后输出条件概率最高的一个。

虽然我们可以使用穷举搜索来获得最优序列,但其计算量 O(YT)\mathcal{O}(\left|\mathcal{Y}\right|^{T'}) 可能高的过分。

例如,当 Y=10000|\mathcal{Y}|=10000T=10T'=10 时,我们需要评估 1000010=104010000^{10} = 10^{40} 序列,这几乎是不可能的!

另一方面,贪心搜索的计算量是 O(YT)\mathcal{O}(\left|\mathcal{Y}\right|T'):通常它要显著地小于穷举搜索。例如,当 Y=10000|\mathcal{Y}|=10000T=10T'=10 时,我们只需要评估 10000×10=10510000\times10=10^5 个序列。

束搜索

那么该选取哪种序列搜索策略呢?如果只有正确性最重要,则显然是穷举搜索。如果计算成本最重要,则显然是贪心搜索。
而束搜索的实际应用则介于这两个极端之间。

束搜索(beam search)是贪心搜索的一个改进版本。

它有一个超参数,名为 束宽(beam size)kk

在时间步 11,我们选择具有最高条件概率的 kk 个词元。这 kk 个词元将分别是 kk 个候选输出序列的第一个词元。

在随后的每个时间步,基于上一时间步的 kk 个候选输出序列,我们将继续从 kYk\left|\mathcal{Y}\right| 个可能的选择中挑出具有最高条件概率的 kk 个候选输出序列。

是从所有结果中选两个最大的,而不是平行筛选。

得最终候选输出序列集合后,我们选择以下得分最高的序列作为输出序列:

1LαlogP(y1,,yL)=1Lαt=1LlogP(yty1,,yt1,c),\frac{1}{L^\alpha} \log P(y_1, \ldots, y_{L}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c}),

:eqlabel:eq_beam-search-score

其中 LL 是最终候选序列的长度,α\alpha 通常设置为 0.750.75。因为一个较长的序列在 :eqref:eq_beam-search-score 的求和中会有更多的对数项,因此分母中的 LαL^\alpha 用于惩罚长序列。

束搜索的计算量为 O(kYT)\mathcal{O}(k\left|\mathcal{Y}\right|T'),这个结果介于贪心搜索和穷举搜索之间。实际上,贪心搜索可以看作是一种束宽为 11 的特殊类型的束搜索。通过灵活地选择束宽,束搜索可以在正确率和计算成本之间进行权衡。

小结

  • 序列搜索策略包括贪心搜索、穷举搜索和束搜索。
  • 束搜索通过灵活选择束宽,在正确率和计算成本之间找到平衡。