【Datawhale学习笔记】seq2seq代码实现

举报
JeffDing 发表于 2026/01/13 10:34:59 2026/01/13
【摘要】 Seq2Seq 架构RNN 和 LSTM 处理序列数据。这些模型在三类任务中表现出色多对一(Many-to-One):将整个序列信息压缩成一个特征向量,用于文本分类、情感分析等任务。多对多(Many-to-Many, Aligned):为输入序列的每一个词元(Token)都生成一个对应的输出,如词性标注、命名实体识别等。一对多(One-to-Many):从一个固定的输入(如一张图片、一个类...

Seq2Seq 架构

RNN 和 LSTM 处理序列数据。这些模型在三类任务中表现出色

  1. 多对一(Many-to-One):将整个序列信息压缩成一个特征向量,用于文本分类、情感分析等任务。

  2. 多对多(Many-to-Many, Aligned):为输入序列的每一个词元(Token)都生成一个对应的输出,如词性标注、命名实体识别等。

  3. 一对多(One-to-Many):从一个固定的输入(如一张图片、一个类别标签)生成一个可变长度的序列,例如图像描述生成、音乐生成等。

但是,在自然语言处理中,还存在一类更复杂的、被称为多对多(Many-to-Many, Unaligned) 的任务,它们的输入序列和输出序列的长度可能不相等,且元素之间没有严格的对齐关系。最典型的例子就是机器翻译,比如将“我是中国人”(3个词)翻译成 “I am Chinese”(3个词),但 “我爱人工智能”(3个词)翻译成 “I love artificial intelligence”(4个词)

自编码器组成部分

自编码器由两个部分组成:

  1. 编码器: 读取输入数据(如一张图片、一个向量),并将其压缩成一个低维度的、紧凑的潜在表示 (Latent Representation) 。这个过程可以看作是特征提取或数据压缩。
  2. 解码器: 接收这个潜在表示,并尝试将其重构回原始的输入数据。

Seq2Seq 架构

组件

  1. 编码器:扮演“阅读和理解”的角色。它负责接收整个输入序列,并将其信息压缩成一个固定长度的上下文向量(Context Vector) ,通常记为 CC。这个向量就是输入序列的“语义概要”。

  2. 解码器:扮演“组织语言并生成”的角色。它接收上下文向量 CC 作为初始信息,然后逐个生成输出序列中的词元。

核心思想

借鉴人类进行翻译的过程——先完整地阅读并理解源语言的整个句子,形成一个综合的语义表示;然后,基于这个语义表示,开始用目标语言逐词生成译文。

目标:

从 Input 到 Output 的转换,而非重构

编码器 (Encoder)

编码器的任务是生成上下文向量 CC

  • 它可以是一个标准的 RNN(或 LSTM),逐个读取输入序列的词元 x1,x2,,xTx_1, x_2, \dots, x_T
  • 在每个时间步,它都会根据前一时刻的状态和当前输入来更新自身状态。对于标准 RNN,这个过程可以简化为 ht=f(ht1,xt)h_t = f(h_{t-1}, x_t);而对于 LSTM,则同时更新隐藏状态和细胞状态: (ht,ct)=LSTM((ht1,ct1),xt)(h_t, c_t) = \text{LSTM}((h_{t-1}, c_{t-1}), x_t)
  • 当处理完最后一个输入词元 xTx_T后,编码器最终的状态就被用作整个输入序列的上下文向量 CC。对于 LSTM,上下文向量 CC 通常就是最后一个时间步的隐藏状态和细胞状态的元组,即 C=(hT,cT)C = (h_T, c_T)。虽然这是最常见的做法,但上下文向量 CC 也可以由所有时间步的隐藏状态 h1,h2,,hT{h_1, h_2, \dots, h_T} 经过某种变换(如拼接后通过一个线性层、或取平均池化)得到,以期保留更全面的序列信息。在图中,编码器依次处理英文单词 “I”、“love”、“you” 的词嵌入向量,并将最终的状态打包成上下文向量(Context Vector)传递给解码器。

解码器 (Decoder)

  • 解码器同样可以使用一个标准的 RNN(或 LSTM)作为核心,但它扮演的角色是生成器而非信息压缩器,因此其工作流程与编码器有显著差异。
  • 初始化 :解码器的初始状态直接由编码器生成的上下文向量 CC 初始化。对于 LSTM,这意味着初始的隐藏状态和细胞状态 (h0,c0)(h^{\prime}_0, c^{\prime}_0) 都被设置为编码器的最终状态 C=(hT,cT)C=(h_T, c_T)。这相当于将整个输入序列的“语义概要”交给了解码器。
  • 自回归生成 (Auto-regressive Generation) :解码器逐个生成词元。

PyTorch 代码实现与分析

编码器 (Encoder)

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False
        )

    def forward(self, x):
        # x shape: (batch_size, seq_length)
        embedded = self.embedding(x)
        # 返回最终的隐藏状态和细胞状态作为上下文
        _, (hidden, cell) = self.rnn(embedded)
        return hidden, cell
  1. init:

self.embedding: 定义词嵌入层,将输入的词元ID(整数)映射为稠密的 hidden_size 维度向量。
self.rnn: 定义 LSTM 层。input_size 和 hidden_size 均为 hidden_size,因为词嵌入向量的维度与 LSTM 隐藏状态的维度在此设计中保持一致。此处为简化演示选择单向(bidirectional=False);实际工程中编码器常使用双向 RNN 以获取更充分的上下文,需要将双向状态(如拼接/线性映射)转换为解码器的初始状态。

  1. forward(self, x):

输入 x 是一个形状为 (batch_size, seq_length) 的张量,代表了一批句子的词元ID序列。
embedded = self.embedding(x): 输入经过词嵌入层,形状变为 (batch_size, seq_length, hidden_size)。
_, (hidden, cell) = self.rnn(embedded): self.rnn 处理整个嵌入序列后,会返回两个内容:
outputs: 包含了序列中每一个时间步的隐藏状态。对于编码器而言,中间步骤的输出通常不被使用,因此用 _ 接收。
(hidden, cell): 一个元组,包含了整个序列最后一个时间步的隐藏状态和细胞状态。这正是我们需要的、概括了整个输入序列信息的上下文向量。
return hidden, cell: 函数最终返回这两个状态,作为上下文传递给解码器。这种实现方式对应了 2.2 节中描述的最经典的做法,即直接使用编码器最后一个时间步的状态作为上下文向量 CC

解码器 (Decoder)

class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden, cell):
        # x shape: (batch_size),只包含当前时间步的token
        x = x.unsqueeze(1) # -> (batch_size, 1)

        embedded = self.embedding(x)
        # 接收上一步的状态 (hidden, cell),计算当前步
        outputs, (hidden, cell) = self.rnn(embedded, (hidden, cell))

        predictions = self.fc(outputs.squeeze(1)) # -> (batch_size, vocab_size)
        return predictions, hidden, cell
  1. init:

self.embedding 和 self.rnn: 与编码器中的定义类似。
self.fc: 增加了一个全连接层(Linear),它的作用是将 LSTM 输出的 hidden_size 维度的隐藏状态,映射到 vocab_size 维度的向量上。这个向量的每一个元素对应词汇表中一个词的得分(logit),后续可以通过 Softmax 函数转换为概率。

  1. forward(self, x, hidden, cell):

这是一个单步的前向传播函数,其输入 x 是一个形状为 (batch_size,) 的张量,仅包含当前时间步的词元ID。
x = x.unsqueeze(1): 为了适应 nn.Embedding 和 nn.LSTM 对输入形状(需要有序列长度维度)的要求,需要给 x 增加一个长度为1的“伪序列”维度,使其形状变为 (batch_size, 1)。
embedded = self.embedding(x): 词元经过嵌入,形状变为 (batch_size, 1, hidden_size)。
outputs, (hidden, cell) = self.rnn(embedded, (hidden, cell)): 解码器的 RNN 接收两个输入:当前步的嵌入向量 embedded,以及上一步传递过来的隐藏状态 (hidden, cell)。它只进行一步计算,然后返回当前步的输出 outputs 和更新后的状态 (hidden, cell)。
predictions = self.fc(outputs.squeeze(1)): RNN 的输出 outputs 形状是 (batch_size, 1, hidden_size),需要用 squeeze(1) 移除长度为1的序列维度,再送入全连接层,得到形状为 (batch_size, vocab_size) 的最终预测。
return predictions, hidden, cell: 返回当前步的预测,以及更新后的状态,用于下一步的计算。

Seq2Seq 包装模块

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        hidden, cell = self.encoder(src)

        # 第一个输入是 <SOS>
        input = trg[:, 0]

        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t, :] = output

            # 决定是否使用 Teacher Forcing
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            # 如果 teacher_force,下一个输入是真实值;否则是模型的预测值
            input = trg[:, t] if teacher_force else top1

        return outputs

forward 函数接收源序列 src (形状 (batch_size, src_len)) 和目标序列 trg (形状 (batch_size, trg_len)),并模拟了训练过程中的一个批次计算:

  1. 初始化:
    outputs = torch.zeros(…): 创建一个形状为 (batch_size, trg_len, vocab_size) 的全零张量,用于存储解码器在每一个时间步的输出 logits。
    hidden, cell = self.encoder(src): 调用编码器处理源序列 src,得到初始的上下文向量。hidden 和 cell 的形状均为 (num_layers, batch_size, hidden_size)。

  2. 启动解码
    input = trg[:, 0] 取出目标序列 trg 的第一个词元(通常是 SOS 标志),作为解码器循环的起始输入。

  3. 循环解码:
    for t in range(1, trg_len): 循环从第二个词元(索引为1)开始,直到目标序列结束。
    output, hidden, cell = self.decoder(input, hidden, cell): 调用解码器执行单步计算。它接收形状为 (batch_size) 的 input 和上一时刻的状态,返回当前步的预测 output 和更新后的状态。
    outputs[:, t, :] = output: 将当前步的预测存入 outputs 张量中。

  4. 教师强制:
    teacher_force = random.random() < teacher_forcing_ratio: 以一定的概率决定是否启用教师强制。
    top1 = output.argmax(1): 找出当前步预测概率最高的词元ID,得到形状为 (batch_size) 的张量 top1。
    input = trg[:, t] if teacher_force else top1: 这是教师强制的关键。根据 teacher_force 的值,选择真实的下一个词元 trg[:, t] 或模型自己的预测 top1 作为下一步的输入。无论哪种情况,下一步的 input 形状都将是 (batch_size)。

  5. 返回: 最终返回 outputs 张量,其形状为 (batch_size, trg_len, vocab_size),用于后续与真实标签计算损失。

高效的推理实现

# ... 在 Seq2Seq 类中 ...
    def greedy_decode(self, src, max_len=12, sos_idx=1, eos_idx=2):
        """推理模式下的高效贪心解码。"""
        self.eval()
        with torch.no_grad():
            hidden, cell = self.encoder(src)
            trg_indexes = [sos_idx]
            for _ in range(max_len):
                # 1. 输入只有上一个时刻的词元
                trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(self.device)
                
                # 2. 解码一步,并传入上一步的状态
                output, hidden, cell = self.decoder(trg_tensor, hidden, cell)
                
                # 3. 获取当前步的预测,并更新状态用于下一步
                pred_token = output.argmax(1).item()
                trg_indexes.append(pred_token)
                if pred_token == eos_idx:
                    break
        return trg_indexes

这种方式通过状态的传递与更新避免了重复计算:

  1. hidden, cell = self.encoder(src): 在循环开始前,只调用一次编码器,获取初始上下文。

  2. 循环内部:
    trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(self.device): 每次的输入仅仅是上一步生成的最后一个词元 trg_indexes[-1],而不是整个序列。
    output, hidden, cell = self.decoder(trg_tensor, hidden, cell): 将这个单词元输入和上一步的 hidden, cell 状态送入解码器。解码器只执行一步计算,并返回新的 hidden, cell 状态。
    这两个新状态会覆盖旧的状态变量,并在下一次循环中被用作输入。
    通过这种方式,信息流和状态在时间步之间平稳地传递,每个时间步都只进行一次必要的计算。

上下文向量的另一种用法

class DecoderAlt(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(DecoderAlt, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        # 主要改动 1: RNN的输入维度是 词嵌入+上下文向量
        self.rnn = nn.LSTM(
            input_size=hidden_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden_ctx, hidden, cell):
        x = x.unsqueeze(1)
        embedded = self.embedding(x)

        # 主要改动 2: 将上下文向量与当前输入拼接
        # 这里简单地取编码器最后一层的 hidden state 作为上下文代表
        context = hidden_ctx[-1].unsqueeze(1).repeat(1, embedded.shape[1], 1)
        rnn_input = torch.cat((embedded, context), dim=2)

        # 解码器的初始状态 hidden, cell 在第一步可设为零;之后需传递并更新上一步状态
        outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        predictions = self.fc(outputs.squeeze(1))
        return predictions, hidden, cell
  1. init:
  • self.rnn = nn.LSTM(…): 这里的主要改动是 input_size=hidden_size + hidden_size。因为在每个时间步,输入给 LSTM 的不再仅仅是词嵌入向量(维度 hidden_size),而是词嵌入向量与上下文向量(维度也是 hidden_size)拼接后的新向量,因此输入维度加倍。
  1. forward(self, x, hidden_ctx, hidden, cell):
  • context = hidden_ctx[-1].unsqueeze(1).repeat(1, embedded.shape[1], 1): 这一步是为了准备用于拼接的上下文向量。
  • rnn_input = torch.cat((embedded, context), dim=2): 核心操作,在最后一个维度(特征维度)上,将词嵌入向量和上下文向量拼接起来,形成 RNN 的最终输入。
  • outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell)): 将拼接后的向量送入 RNN。注意,这里传入的 (hidden, cell) 是解码器自身的上一步状态(初始是零向量),而不是编码器传来的上下文 hidden_ctx。上下文信息已经通过输入端注入了。

完整程序

import torch
import torch.nn as nn
import random

torch.manual_seed(42)

# 1. 全局配置
batch_size = 8
src_len = 10
trg_len = 12
src_vocab_size = 100
trg_vocab_size = 120
hidden_size = 64
num_layers = 2
sos_idx = 1  # 句子起始标记索引
eos_idx = 2  # 句子结束标记索引


# 2. 模型定义

class Encoder(nn.Module):
    """编码器: 读取输入序列,输出上下文向量(隐藏状态和细胞状态)。"""
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size,      # 输入特征维度
            hidden_size=hidden_size,     # 隐藏状态维度
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False
        )

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, cell) = self.rnn(embedded)
        return hidden, cell


class Decoder(nn.Module):
    """解码器(标准实现): 接收上一个预测的token和当前状态,单步输出预测和新状态。"""
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden, cell):
        x = x.unsqueeze(1)
        embedded = self.embedding(x)
        outputs, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        predictions = self.fc(outputs.squeeze(1))
        return predictions, hidden, cell


class Seq2Seq(nn.Module):
    """Seq2Seq 包装模块: 管理 Encoder 和 Decoder。"""
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """训练模式下的前向传播,使用 Teacher Forcing。"""
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        hidden, cell = self.encoder(src)
        input = trg[:, 0]

        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t, :] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[:, t] if teacher_force else top1
        return outputs

    def greedy_decode(self, src, max_len=trg_len):
        """推理模式下的高效贪心解码。"""
        self.eval()
        with torch.no_grad():
            hidden, cell = self.encoder(src)
            trg_indexes = [sos_idx]
            for _ in range(max_len):
                trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(self.device)
                output, hidden, cell = self.decoder(trg_tensor, hidden, cell)
                pred_token = output.argmax(1).item()
                trg_indexes.append(pred_token)
                if pred_token == eos_idx:
                    break
        return trg_indexes


# 3. 变体模型定义

class DecoderAlt(nn.Module):
    """解码器变体: 不用上下文初始化状态,而是在每步将其作为输入。"""
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(DecoderAlt, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        # 注意:这里的输入维度是两个 hidden_size 拼接而成
        self.rnn = nn.LSTM(
            input_size=hidden_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden_ctx, hidden, cell):
        x = x.unsqueeze(1)
        embedded = self.embedding(x)
        context = hidden_ctx[-1].unsqueeze(1).repeat(1, embedded.shape[1], 1)
        rnn_input = torch.cat((embedded, context), dim=2)
        outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        predictions = self.fc(outputs.squeeze(1))
        return predictions, hidden, cell


# 4. 解码策略

def alternative_greedy_decode(encoder, decoder, src, device, max_len=trg_len):
    """配合 DecoderAlt 的解码实现。"""
    with torch.no_grad():
        hidden_ctx, cell_ctx = encoder(src)
        trg_indexes = [sos_idx]
        # 初始化解码器的"真实"状态为0
        batch_size = src.shape[0]
        hidden = torch.zeros(num_layers, batch_size, hidden_size).to(device)
        cell = torch.zeros(num_layers, batch_size, hidden_size).to(device)
        
        for _ in range(max_len):
            trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
            output, hidden, cell = decoder(trg_tensor, hidden_ctx, hidden, cell)
            pred_token = output.argmax(1).item()
            trg_indexes.append(pred_token)
            if pred_token == eos_idx:
                break
    return trg_indexes


# 5. 主流程

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 统一初始化模型 ---
    encoder = Encoder(src_vocab_size, hidden_size, num_layers).to(device)
    decoder = Decoder(trg_vocab_size, hidden_size, num_layers).to(device)
    model = Seq2Seq(encoder, decoder, device).to(device)

    # --- 统一创建伪数据 ---
    src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device)
    trg = torch.randint(1, trg_vocab_size, (batch_size, trg_len)).to(device)

    # =========================================
    # 1: 标准训练 & 高效推理
    # =========================================
    print("\n" + "="*25 + " 1: 标准模式 " + "="*25)
    # 训练过程模拟 (Teacher Forcing)
    model.train()
    outputs = model(src, trg, teacher_forcing_ratio=0.8)
    print(f"训练模式输出张量形状: {outputs.shape}")
    # 推理过程模拟 (高效的自回归)
    prediction = model.greedy_decode(src[0:1, :])
    print(f"高效推理的预测结果: {prediction}")

    # =========================================
    # 2: 上下文向量的另一种用法
    # =========================================
    print("\n" + "="*23 + " 2: 上下文变体用法 " + "="*23)
    decoder_alt = DecoderAlt(trg_vocab_size, hidden_size, num_layers).to(device)
    
    prediction_alt = alternative_greedy_decode(encoder, decoder_alt, src[0:1, :], device)
    print(f"变体用法预测结果: {prediction_alt}")

参考资料

https://github.com/datawhalechina/base-llm/blob/main/docs/chapter4/10_seq2seq.md

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。