编码器(encoder):它接受一个长度可变的序列作为输入,并将其转换为具有固定形状的编码状态
解码器(decoder):它将固定形状的编码状态映射到长度可变的序列。
架构图示:
一个模型被分为两块:
编码器
1 2 3 4 5 6 7 8 9 10 11
| from torch import nn
class Encoder(nn.Module): """编码器-解码器结构的基本编码器接口。""" def __init__(self, **kwargs): super(Encoder, self).__init__(**kwargs)
def forward(self, X, *args): raise NotImplementedError
|
解码器
1 2 3 4 5 6 7 8 9 10 11
| class Decoder(nn.Module): """编码器-解码器结构的基本解码器接口。""" def __init__(self, **kwargs): super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args): raise NotImplementedError
def forward(self, X, state): raise NotImplementedError
|
合并编码器和解码器
1 2 3 4 5 6 7 8 9 10 11 12
| class EncoderDecoder(nn.Module): """编码器-解码器结构的基类。""" def __init__(self, encoder, decoder, **kwargs): super(EncoderDecoder, self).__init__(**kwargs) self.encoder = encoder self.decoder = decoder
def forward(self, enc_X, dec_X, *args): enc_outputs = self.encoder(enc_X, *args) dec_state = self.decoder.init_state(enc_outputs, *args) return self.decoder(dec_X, dec_state)
|
小结
- “编码器-解码器”结构可以将长度可变的序列作为输入和输出,因此适用于机器翻译等序列转换问题。
- 编码器将长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。
- 解码器将具有固定形状的编码状态映射为长度可变的序列。