【Datawhale学习笔记】注意力机制及Transform代码实践

举报
JeffDing 发表于 2026/01/13 12:02:57 2026/01/13
【摘要】 注意力机制 设计原理在解码器生成每一个词元时,不再依赖一个固定的上下文向量,而是允许它“回头看”一遍完整的输入序列,并根据当前解码的需求,自主地为输入序列的每个部分分配不同的注意力权重,然后基于这些权重将输入信息加权求和,生成一个动态的、专属当前时间步的上下文向量。通俗地理解为从“一言以蔽之”到“择其要者而观之”的转变 注意力机制详解 三部曲计算相似度使用解码器上一时刻的隐藏状态 ht−1...

注意力机制

设计原理

在解码器生成每一个词元时,不再依赖一个固定的上下文向量,而是允许它“回头看”一遍完整的输入序列,并根据当前解码的需求,自主地为输入序列的每个部分分配不同的注意力权重,然后基于这些权重将输入信息加权求和,生成一个动态的、专属当前时间步的上下文向量。
通俗地理解为从“一言以蔽之”到“择其要者而观之”的转变

注意力机制详解

三部曲

  1. 计算相似度
    使用解码器上一时刻的隐藏状态 ht1h^{\prime}_{t-1} 与编码器的每一个隐藏状态 hjh_j 计算一个分数,这个分数衡量了在当前解码时刻,应当对第 jj 个输入词元投入多少“关注”。

etj=score(ht1,hj)e_{tj} = \text{score}(h^{\prime}_{t-1}, h_j)

这个分数越高,代表关联性越强。计算这个分数的方式有很多种,例如简单的点积、或者引入一个可学习的神经网络层。

  1. 计算注意力权重
    得到输入序列所有位置的注意力分数 (et1,et2,,et,Tx)(e_{t1}, e_{t2}, \dots, e_{t,T_x}) 后,为了将它们转换成一种“权重”的表示,可使用 Softmax 函数对其进行归一化。这样,就能得到一组总和为 1、且均为正数的注意力权重 (αt1,αt2,,αt,Tx)(\alpha_{t1}, \alpha_{t2}, \dots, \alpha_{t,T_x})

αtj=softmax(etj)=exp(etj)i=1Txexp(eti)\alpha_{tj} = \text{softmax}(e_{tj}) = \frac{\exp(e_{tj})}{\sum_{i=1}^{T_x} \exp(e_{ti})}

这组权重 αt\alpha_t 构成了一个概率分布,清晰地表明了在当前解码步骤 tt,注意力应该如何分配在输入序列的各个位置上。

  1. 加权求和,生成上下文向量
    最后,使用上一步得到的注意力权重 αtj\alpha_{tj},对编码器的所有隐藏状态 hjh_j 进行加权求和,从而得到当前解码时刻 tt 专属的上下文向量 CtC_t

Ct=j=1TxαtjhjC_t = \sum_{j=1}^{T_x} \alpha_{tj} h_j

这个 CtC_t 向量,由于是根据当前解码需求动态生成的,它比原始 Seq2Seq 的那个固定向量 CC 包含了更具针对性的信息。

PyTorch 实现与代码解析

编码器

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True  # 使用双向LSTM
        )
        self.fc = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x):
        embedded = self.embedding(x)
        outputs, (hidden, cell) = self.rnn(embedded)
        
        # 将双向RNN的输出通过线性层降维,使其与解码器维度匹配
        outputs = torch.tanh(self.fc(outputs))

        return outputs, hidden, cell
  • bidirectional=True:启用双向 LSTM,使原始 RNN outputs 维度变为 (batch, src_len, hidden_size * 2)。
  • self.fc:定义一个线性层,将拼接后的双向输出映射回 hidden_size 维度;经过 self.fc 和 tanh 后,outputs 维度回到 (batch, src_len, hidden_size),方便后续计算。
  • return outputs, …:返回降维后的所有时间步输出 outputs (作为后续的 K 和 V),以及原始的最终状态 hidden 和 cell。

注意力模块的两种实现

无参数的注意力

class AttentionSimple(nn.Module):
    """1: 无参数的注意力模块"""
    def __init__(self, hidden_size):
        super(AttentionSimple, self).__init__()
        # 确保缩放因子是一个 non-learnable buffer
        self.register_buffer("scale_factor", torch.sqrt(torch.FloatTensor([hidden_size])))

    def forward(self, hidden, encoder_outputs):
        # hidden shape: (num_layers, batch_size, hidden_size)
        # encoder_outputs shape: (batch_size, src_len, hidden_size)
        
        # Q: 解码器最后一层的隐藏状态
        query = hidden[-1].unsqueeze(1)  # -> (batch, 1, hidden)
        # K/V: 编码器的所有输出
        keys = encoder_outputs  # -> (batch, src_len, hidden)

        # energy shape: (batch, 1, src_len)
        energy = torch.bmm(query, keys.transpose(1, 2)) / self.scale_factor
        
        # attention_weights shape: (batch, src_len)
        return torch.softmax(energy, dim=2).squeeze(1)

带参数的注意力

class AttentionParams(nn.Module):
    """2: 带参数的注意力模块"""
    def __init__(self, hidden_size):
        super(AttentionParams, self).__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[1]
        hidden_last_layer = hidden[-1].unsqueeze(1).repeat(1, src_len, 1)
        
        energy = torch.tanh(self.attn(torch.cat((hidden_last_layer, encoder_outputs), dim=2)))
        attention = torch.sum(self.v * energy, dim=2)
        
        return torch.softmax(attention, dim=1)

通用解码器

class DecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, attention_module):
        super(DecoderWithAttention, self).__init__()
        self.attention = attention_module
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size * 2,  # 输入维度是 词嵌入(hidden_size) + 上下文向量(hidden_size)
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden, cell, encoder_outputs):
        embedded = self.embedding(x.unsqueeze(1))

        # 1. 计算注意力权重
        # a shape: [batch, src_len]
        a = self.attention(hidden, encoder_outputs).unsqueeze(1)
        
        # 2. 计算上下文向量
        context = torch.bmm(a, encoder_outputs)

        # 3. 将上下文向量与当前输入拼接
        rnn_input = torch.cat((embedded, context), dim=2)

        # 4. 传入RNN解码
        outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        
        # 5. 预测输出
        predictions = self.fc(outputs.squeeze(1))
        
        return predictions, hidden, cell

Seq2Seq 包装模块

class Seq2Seq(nn.Module):
    """带注意力的Seq2Seq"""
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        encoder_outputs, hidden, cell = self.encoder(src)

        # 适配Encoder(双向)和Decoder(单向)的状态维度
        hidden = hidden.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)
        cell = cell.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)

        input = trg[:, 0]
        for t in range(1, trg_len):
            # 在循环的每一步,都将 encoder_outputs 传递给解码器
            # 这是 Attention 机制能够"回顾"整个输入序列的关键
            output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[:, t] if teacher_force else top1
            
        return outputs

完整代码

import torch
import torch.nn as nn
import random

torch.manual_seed(42)

# 1. 全局配置
batch_size = 8
src_len = 10
trg_len = 12
src_vocab_size = 100
trg_vocab_size = 120
hidden_size = 64
num_layers = 2
sos_idx = 1
eos_idx = 2


# 2. 模型定义

class Encoder(nn.Module):
    """编码器: 读取输入序列,输出所有时间步的特征以及最终状态。"""
    def __init__(self, vocab_size, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True  # 使用双向LSTM
        )
        self.fc = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x):
        embedded = self.embedding(x)
        outputs, (hidden, cell) = self.rnn(embedded)
        outputs = torch.tanh(self.fc(outputs))
        return outputs, hidden, cell

class DecoderWithAttention(nn.Module):
    """带注意力的通用解码器"""
    def __init__(self, vocab_size, hidden_size, num_layers, attention_module):
        super(DecoderWithAttention, self).__init__()
        self.attention = attention_module
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=hidden_size
        )
        self.rnn = nn.LSTM(
            input_size=hidden_size * 2,  # 输入维度是 词嵌入 + 上下文向量
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden, cell, encoder_outputs):
        x = x.unsqueeze(1)
        embedded = self.embedding(x)
        
        a = self.attention(hidden, encoder_outputs).unsqueeze(1)
        context = torch.bmm(a, encoder_outputs)
        rnn_input = torch.cat((embedded, context), dim=2)

        outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        predictions = self.fc(outputs.squeeze(1))
        
        return predictions, hidden, cell

class Seq2Seq(nn.Module):
    """带注意力的Seq2Seq"""
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        encoder_outputs, hidden, cell = self.encoder(src)
        
        hidden = hidden.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)
        cell = cell.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)

        input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[:, t] if teacher_force else top1
            
        return outputs

    def greedy_decode(self, src, max_len=trg_len):
        self.eval()
        with torch.no_grad():
            encoder_outputs, hidden, cell = self.encoder(src)
            hidden = hidden.view(self.encoder.rnn.num_layers, 2, src.shape[0], -1).sum(axis=1)
            cell = cell.view(self.encoder.rnn.num_layers, 2, src.shape[0], -1).sum(axis=1)

            trg_indexes = [sos_idx]
            for _ in range(max_len):
                input_tensor = torch.LongTensor([trg_indexes[-1]]).to(self.device)
                output, hidden, cell = self.decoder(input_tensor, hidden, cell, encoder_outputs)
                pred_token = output.argmax(1).item()
                trg_indexes.append(pred_token)
                if pred_token == eos_idx:
                    break
        return trg_indexes

# 3. Attention 定义

class AttentionSimple(nn.Module):
    """1: 无参数的注意力模块"""
    def __init__(self, hidden_size):
        super(AttentionSimple, self).__init__()
        # 确保缩放因子是一个 non-learnable buffer
        self.register_buffer("scale_factor", torch.sqrt(torch.FloatTensor([hidden_size])))

    def forward(self, hidden, encoder_outputs):
        # hidden shape: (num_layers, batch_size, hidden_size)
        # encoder_outputs shape: (batch_size, src_len, hidden_size)
        
        # Q: 解码器最后一层的隐藏状态
        query = hidden[-1].unsqueeze(1)  # -> (batch, 1, hidden)
        # K/V: 编码器的所有输出
        keys = encoder_outputs  # -> (batch, src_len, hidden)

        # energy shape: (batch, 1, src_len)
        energy = torch.bmm(query, keys.transpose(1, 2)) / self.scale_factor
        
        # attention_weights shape: (batch, src_len)
        return torch.softmax(energy, dim=2).squeeze(1)

class AttentionParams(nn.Module):
    """2: 带参数的注意力模块"""
    def __init__(self, hidden_size):
        super(AttentionParams, self).__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[1]
        hidden_last_layer = hidden[-1].unsqueeze(1).repeat(1, src_len, 1)
        
        energy = torch.tanh(self.attn(torch.cat((hidden_last_layer, encoder_outputs), dim=2)))
        attention = torch.sum(self.v * energy, dim=2)
        
        return torch.softmax(attention, dim=1)

# 4. 主流程
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 统一创建编码器和伪数据 ---
    encoder = Encoder(src_vocab_size, hidden_size, num_layers).to(device)
    src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device)
    trg = torch.randint(1, trg_vocab_size, (batch_size, trg_len)).to(device)

    # =========================================
    # 1: 无参数的基础 Attention
    # =========================================
    print("\n" + "="*20 + " 1: 无参数的基础 Attention " + "="*20)
    attention_simple = AttentionSimple(hidden_size).to(device)
    decoder_simple = DecoderWithAttention(trg_vocab_size, hidden_size, num_layers, attention_simple).to(device)
    model_simple = Seq2Seq(encoder, decoder_simple, device).to(device)
    
    model_simple.train()
    outputs_simple = model_simple(src, trg)
    print(f"模型结构:\n{model_simple}")
    print(f"\n训练模式输出张量形状: {outputs_simple.shape}")
    
    prediction_simple = model_simple.greedy_decode(src[0:1, :])
    print(f"推理的预测结果: {prediction_simple}")

    # =========================================
    # 2: 带参数的 Attention
    # =========================================
    print("\n" + "="*20 + " 2: 带参数的 Attention " + "="*20)
    attention_params = AttentionParams(hidden_size).to(device)
    decoder_params = DecoderWithAttention(trg_vocab_size, hidden_size, num_layers, attention_params).to(device)
    model_params = Seq2Seq(encoder, decoder_params, device).to(device)

    model_params.train()
    outputs_params = model_params(src, trg)
    print(f"模型结构:\n{model_params}")
    print(f"\n训练模式输出张量形状: {outputs_params.shape}")
    
    prediction_params = model_params.greedy_decode(src[0:1, :])
    print(f"推理的预测结果: {prediction_params}")

注意力机制的类型

Soft Attention vs. Hard Attention

  • Soft Attention:这就是前文一直在详细讨论的机制。它为输入序列的所有位置都计算一个注意力权重,这些权重是 0 到 1 之间的浮点数(经 Softmax 归一化),然后进行加权求和。这种方式的优点是模型是端到端可微的,可以使用标准的梯度下降法进行训练。其缺点是在处理非常长的序列时,计算开销会很大。因为解码的每一步,都需要计算当前状态与所有输入状态的相似度。

  • Hard Attention3:与 Soft Attention 对所有输入进行加权不同,Hard Attention 在每一步只选择一个最相关的输入位置。可以看作是一种“非 0 即 1”的注意力分配,即选中的位置权重为 1,其他所有位置的权重均为 0。这样做的好处是计算量大大减少,因为不再需要进行全面的加权求和。但它的缺点也很突出:选择过程是离散的、不可微的,因此无法使用常规的反向传播算法进行训练,通常需要借助强化学习等更复杂的技巧。

Global Attention vs. Local Attention

  • Global Attention (全局注意力):其思想与 Soft Attention 基本一致,即在计算注意力时,会考虑编码器的所有隐藏状态。

  • Local Attention (局部注意力):这是一种介于 Soft Attention 和 Hard Attention 之间的折中方案。能够减少计算量,但又不像 Hard Attention 那样极端。其核心思想是,在每个解码时间步,只关注输入序列的一个局部窗口。它的工作流程通常是:

  1. 预测对齐位置:首先,模型需要预测一个当前解码步最关注的源序列位置 ptp_t。这个位置可以通过一个小型神经网络,仅依赖于当前解码器状态 hth^\prime_t 来预测,从而避免了与所有编码器状态进行比较,降低了计算成本。预测公式可以设计为: pt=Txsigmoid(Wpht+bp)p_t = T_x \cdot \text{sigmoid}(W_p h'_t + b_p),其中 TxT_x 是源序列长度, WpW_pbpb_p 是可学习的参数。

  2. 定义窗口:以预测出的 ptp_t 为中心,定义一个大小为 2D+12D+1 的窗口,其中 DD 是一个超参数。

  3. 局部计算:最后,模型只在这个窗口内的编码器状态上应用 Soft Attention 机制,计算权重并生成上下文向量。

Transformer

Transformer的提出

2017年,Google 的研究团队发表了一篇名为《Attention Is All You Need》的论文,提出了一种全新的架构——Transformer

自注意力机制

从根本上说,要让模型理解一段文本,就需要提取其“序列特征”,即将文本中所有词元的信息以某种方式整合起来。RNN 通过依次传递隐藏状态来顺序地整合信息,而 Transformer 则选择了一条截然不同的道路。其核心是 自注意力机制。它不再依赖于顺序计算,而是将提取序列特征的过程看作是输入序列“自己对自己进行注意力计算”。序列中的每个词元都会“审视”序列中的所有其他词元,来动态地计算出最能代表当前词元上下文含义的新表示。与上一节介绍的交叉注意力不同,在自注意力中,Query、Key、Value 均来源于同一个输入序列。

PyTorch 实现自注意力

class SelfAttention(nn.Module):
    """自注意力模块"""
    def __init__(self, hidden_size):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, x):
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.hidden_size)
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, v)
        
        return context
  • init: 初始化了三个 nn.Linear 层,它们分别对应将输入映射到 Q, K, V 空间的权重矩阵 WQ,WK,WVW^Q, W^K, W^V
    forward:
    • q_linear(x), k_linear(x), v_linear(x):将形状为 [batch_size, seq_len, hidden_size] 的输入张量 x 分别通过三个线性层,一次性地为序列中的所有词元计算出 Q, K, V 矩阵。
    • torch.matmul(q, k.transpose(-2, -1)): 这是实现并行计算的核心。通过将 K 矩阵的最后两个维度转置(seq_len, hidden_size -> hidden_size, seq_len),再与 Q 矩阵相乘,直接得到了一个 [batch_size, seq_len, seq_len] 的分数矩阵。该矩阵中的 scores[b, i, j] 代表了批次 b 中第 i 个词元对第 j 个词元的注意力分数。
    • / math.sqrt(self.hidden_size):执行缩放操作,防止梯度消失。
    • torch.softmax(scores, dim=-1):对分数的最后一个维度(seq_len)进行 Softmax,得到归一化的注意力权重。
    • torch.matmul(attention_weights, v):将权重矩阵与 V 矩阵相乘,完成了对所有词元的 Value 向量的加权求和,得到最终的上下文感知表示。

多头注意力机制

仅仅用一组 WQ,WK,WVW^Q, W^K, W^V 矩阵进行一次自注意力计算,相当于只从一个“视角”来审视文本内在的关系。然而,文本中的关系是多层次的,例如,一组参数可能学会了关注代词(如 “它” 指向谁)的关系,但可能忽略了动作的执行者(主谓宾)等其他类型的关系。

为了让模型能够综合利用从不同维度和视角提取出的信息,Transformer 引入了多头注意力机制 (Multi-Head Attention)。其思想非常直接:并行地执行多次自注意力计算,每一次计算都是一个独立的“头 (Head)”。每个头都拥有一组自己专属的 WiQ,WiK,WiVW^Q_i, W^K_i, W^V_i 权重矩阵,并且可以学习去关注一种特定类型的上下文关系。

PyTorch 实现多头注意力

class MultiHeadSelfAttention(nn.Module):
    """多头自注意力模块"""
    def __init__(self, hidden_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.wo = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        
        # 拆分多头
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 并行计算注意力
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, v)
        
        # 合并多头结果
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        
        # 输出层
        output = self.wo(context)
        
        return output

Transformer 整体结构

编码器(Encoder)

编码器的作用是“理解”和“消化”输入的整个序列,为序列中的每个词元生成一个富含上下文信息的表示。一个标准的编码器层由两个主要的子层构成,分别是多头自注意力层(Multi-Head Self-Attention Layer)和位置前馈网络(Position-wise Feed-Forward Network)。每个子层的输出都经过了**残差连接(Add)与层归一化(Norm)**处理。所以,一个编码器层内部的数据流可以表示为 x -> Sublayer1(x) -> Add & Norm -> Sublayer2(…) -> Add & Norm。

关键特性:

  • 注意力类型:编码器中的多头注意力层是双向的自注意力。这意味着在计算时,序列中的任何一个词元都可以“看到”序列中的所有其他词元(包括它自己、它前面的和它后面的)。
  • 功能:由于其双向性,编码器非常擅长理解完整的输入文本,并为每个词元生成一个深度融合了上下文信息的表示。
  • 应用:通过大量堆叠编码器层而构建的模型(Encoder-Only 架构),如 BERT,在文本分类、命名实体识别等自然语言理解(NLU)任务上取得了巨大成功。

解码器(Decoder)

解码器的作用是基于编码器对原始输入的理解,并结合已经生成的部分,来逐个生成下一个词元。为了完成这个更复杂的任务,一个标准的解码器层(Decoder Layer)比编码器层多了一个注意力子层,总共包含三个子层。分别是带掩码的多头自注意力层(Masked Multi-Head Self-Attention Layer)、交叉注意力层(Cross-Attention Layer)和位置前馈网络(Position-wise Feed-Forward Network)。同样,解码器的每个子层也都采用了残差连接和层归一化。

Transformer 代码实践

搭建整体框架

  • Embedding 层:将输入的 token ID 转换为连续的向量表示,并加上位置编码以保留序列顺序信息。
  • Encoder 堆叠:由 NN 个 EncoderLayer 串联而成,负责深度提取和理解输入序列的特征。
  • Decoder 堆叠:由 NN 个 DecoderLayer 串联而成,负责基于 Encoder 的输出逐步生成目标序列。
  • Output 层:一个线性层,将解码器的最终输出映射回词表大小,用于计算下一个词的概率分布。
# src/transformer.py
import torch.nn as nn
from .pos import PositionalEncoding  # 稍后实现
# ... 导入其他组件

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, dim=512, n_heads=8, n_layers=6, ...):
        super().__init__()

        self.dim = dim
        # 1. 嵌入层与位置编码
        # src_embedding: 将源语言序列映射为向量 (Encoder输入)
        self.src_embedding = nn.Embedding(src_vocab_size, dim)
        # tgt_embedding: 将目标语言序列映射为向量 (Decoder输入)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, dim)
        self.pos_encoder = PositionalEncoding(dim, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        
        # 2. 编码器与解码器堆叠
        # 使用 ModuleList 来存储层列表,支持按索引访问和自动注册参数
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(dim, n_heads, hidden_dim, dropout) for _ in range(n_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(dim, n_heads, hidden_dim, dropout) for _ in range(n_layers)
        ])
        
        # 3. 输出头
        self.output = nn.Linear(dim, tgt_vocab_size)

    def forward(self, src, tgt):
        # 1. 生成掩码 (Padding Mask & Causal Mask)
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        
        # 2. 编码器前向传播
        enc_output = self.encode(src, src_mask)
        
        # 3. 解码器前向传播
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        
        # 4. 输出 Logits
        return self.output(dec_output)
        return logits

实现核心组件

位置编码

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """
    正弦位置编码
    Transformer 论文中使用固定公式计算位置编码,不涉及可学习参数。
    """
    def __init__(self, dim, max_seq_len=5000):
        super().__init__()
        
        # 创建一个足够长的 PE 矩阵 [max_seq_len, dim]
        pe = torch.zeros(max_seq_len, dim)
        
        # 生成位置索引 [0, 1, ..., max_seq_len-1] -> [max_seq_len, 1]
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        
        # 计算分母中的 div_term: 10000^(2i/dim) = exp(2i * -log(10000)/dim)
        # 这种对数变换的计算方式在数值上更稳定
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        
        # 填充 PE 矩阵
        # 偶数维度用 sin,奇数维度用 cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 增加 batch 维度: [1, max_seq_len, dim] 以便广播
        pe = pe.unsqueeze(0)
        
        # 注册为 buffer
        # register_buffer 的作用是告诉 PyTorch:
        # 1. 'pe' 是模型状态的一部分,会随模型保存和加载 (state_dict)。
        # 2. 'pe' 不是模型参数 (Parameter),优化器更新时不会更新它。
        self.register_buffer('pe', pe)

在前向传播中,我们的任务就是将位置编码加到输入的词嵌入上。由于我们预先生成的 pe 矩阵可能比当前的输入序列 x 要长,所以需要根据 x 的实际长度对 pe 进行切片。

class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_seq_len=5000):
        ...

    def forward(self, x):
        """
        Args:
            x: 输入的词嵌入序列 [batch_size, seq_len, dim]
        Returns:
            加上位置编码后的序列 [batch_size, seq_len, dim]
        """
        # 截取与输入序列长度对应的位置编码并相加
        # x.size(1) 是 seq_len
        # self.pe 的形状是 [1, max_seq_len, dim],切片后会自动广播到 batch_size
        x = x + self.pe[:, :x.size(1), :]
        return x

最后,我们可以编写一段简单的测试代码来验证维度是否正确。

if __name__ == "__main__":
    # 准备参数
    batch_size, seq_len, dim = 2, 10, 512
    max_seq_len = 100
    
    # 初始化模块
    pe = PositionalEncoding(dim, max_seq_len)
    
    # 准备输入
    x = torch.zeros(batch_size, seq_len, dim) # 输入为0,直接观察PE值
    
    # 前向传播
    output = pe(x)
    
    # 验证输出
    print("--- PositionalEncoding Test ---")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

多头注意力

# src/attention.py
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        
        # 定义 Wq, Wk, Wv 矩阵
        # 这里我们使用一个大的线性层一次性计算所有头的 Q, K, V
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        
        # 最终输出的线性层 Wo
        self.wo = nn.Linear(dim, dim)
        
        self.dropout = nn.Dropout(dropout)

这部分前向传播的重点是“分头”操作。我们不直接对 [batch, seq_len, dim] 进行计算,而是将其 reshape 为 [batch, n_heads, seq_len, head_dim],这样就可以利用矩阵运算并行地处理所有头。

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_heads, dropout=0.1):
        ...

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # 1. 线性投影
        # [batch, seq_len, dim] -> [batch, seq_len, dim]
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        
        # 2. 分头 (Split Heads)
        # 变换形状: [batch, seq_len, n_heads, head_dim] 
        # 然后转置: [batch, n_heads, seq_len, head_dim] 以便并行计算
        q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        
        # 3. 计算缩放点积注意力 (Scaled Dot-Product Attention)
        # scores: [batch, n_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # 4. 应用掩码 (Masking)
        if mask is not None:
            # mask == 0 的位置被填充为负无穷,Softmax 后变为 0
            scores = scores.masked_fill(mask == 0, float('-inf'))
            
        # 5. Softmax 与加权求和
        attn_weights = torch.softmax(scores, dim=-1)
        
        if self.dropout is not None:
             attn_weights = self.dropout(attn_weights)
             
        # context: [batch, n_heads, seq_len, head_dim]
        context = torch.matmul(attn_weights, v)
        
        # 6. 合并多头 (Concat Heads)
        # [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, n_heads, head_dim]
        # -> [batch, seq_len, dim]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
        
        # 7. 输出层投影
        output = self.wo(context)
        
        return output

# 单元测试
if __name__ == "__main__":
    # 准备参数
    batch_size, seq_len, dim = 2, 10, 512
    n_heads = 8
    
    # 初始化模块
    mha = MultiHeadAttention(dim, n_heads)
    
    # 准备输入 (Query, Key, Value 相同)
    x = torch.randn(batch_size, seq_len, dim)
    
    # 前向传播
    output = mha(x, x, x)
    
    # 验证输出
    print("--- MultiHeadAttention Test ---")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

前馈神经网络

# src/ffn.py
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)  # 升维
        self.w2 = nn.Linear(hidden_dim, dim)  # 降维
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # 线性变换 -> ReLU -> Dropout -> 线性变换
        return self.w2(self.dropout(torch.relu(self.w1(x))))

if __name__ == "__main__":
    # 准备参数
    batch_size, seq_len, dim = 2, 10, 512
    hidden_dim = 2048
    
    # 初始化模块
    ffn = FeedForward(dim, hidden_dim)
    
    # 准备输入
    x = torch.randn(batch_size, seq_len, dim)
    
    # 前向传播
    output = ffn(x)
    
    # 验证输出
    print("--- FeedForward Test ---")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

层归一化

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """
    层归一化 (Layer Normalization)
    公式: y = (x - mean) / sqrt(var + eps) * gamma + beta
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # 可学习参数 gamma (缩放) 和 beta (偏移)
        # nn.Parameter 会被自动注册为模型参数
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # x: [batch_size, seq_len, dim]
        # 在最后一个维度 (dim) 上计算均值和方差
        # keepdim=True 保持维度以便进行广播计算
        mean = x.mean(-1, keepdim=True)
        # unbiased=False 使用有偏估计 (分母为 N),与 PyTorch 默认行为一致
        var = x.var(-1, keepdim=True, unbiased=False)
        
        # 归一化
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        
        # 缩放和平移
        return self.gamma * x_norm + self.beta

# 单元测试
if __name__ == "__main__":
    # 准备参数
    batch_size, seq_len, dim = 2, 10, 512
    
    # 初始化模块
    ln = LayerNorm(dim)
    
    # 准备输入
    x = torch.randn(batch_size, seq_len, dim)
    
    # 前向传播
    output = ln(x)
    
    # 验证输出
    print("--- LayerNorm Test ---")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

组装与运行

核心框架

import torch
import torch.nn as nn
import math
# 导入组件
from .attention import MultiHeadAttention
from .ffn import FeedForward
from .norm import LayerNorm
from .pos import PositionalEncoding

class EncoderLayer(nn.Module):
    def __init__(self, dim, n_heads, hidden_dim, dropout=0.1):
        super().__init__()
        # 多头自注意力子层
        self.attention = MultiHeadAttention(dim, n_heads, dropout)
        self.attention_norm = LayerNorm(dim)
        # 前馈网络子层
        self.feed_forward = FeedForward(dim, hidden_dim, dropout)
        self.ffn_norm = LayerNorm(dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # 子层 1:自注意力
        _x = x
        x = self.attention(x, x, x, mask)  # Q=K=V=x
        x = self.attention_norm(_x + self.dropout(x))
        
        # 子层 2:前馈网络
        _x = x
        x = self.feed_forward(x)
        x = self.ffn_norm(_x + self.dropout(x))
        
        return x

接下来是解码器层,这部分比编码器层多了一个“交叉注意力”子层,先是带掩码的自注意力,再是对编码器输出的交叉注意力,最后是前馈网络。

class DecoderLayer(nn.Module):
    def __init__(self, dim, n_heads, hidden_dim, dropout=0.1):
        super().__init__()
        # 1. 带掩码的自注意力
        self.self_attention = MultiHeadAttention(dim, n_heads, dropout)
        self.self_attn_norm = LayerNorm(dim)
        # 2. 交叉注意力
        self.cross_attention = MultiHeadAttention(dim, n_heads, dropout)
        self.cross_attn_norm = LayerNorm(dim)
        # 3. 前馈网络
        self.feed_forward = FeedForward(dim, hidden_dim, dropout)
        self.ffn_norm = LayerNorm(dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        # 子层 1:带掩码的自注意力
        _x = x
        x = self.self_attention(x, x, x, tgt_mask)
        x = self.self_attn_norm(_x + self.dropout(x))
        
        # 子层 2:交叉注意力(Q 来自解码器,K/V 来自编码器输出)
        _x = x
        x = self.cross_attention(x, enc_output, enc_output, src_mask)
        x = self.cross_attn_norm(_x + self.dropout(x))
        
        # 子层 3:前馈网络
        _x = x
        x = self.feed_forward(x)
        x = self.ffn_norm(_x + self.dropout(x))
        
        return x

最后在 Transformer 主类中,我们需要补全相关的辅助方法。

class Transformer(nn.Module):
    def __init__(self, 
                 src_vocab_size, 
                 tgt_vocab_size, 
                 dim=512, 
                 n_heads=8, 
                 n_layers=6, 
                 hidden_dim=2048, 
                 max_seq_len=5000, 
                 dropout=0.1):
        # ... 初始化嵌入层、位置编码、编码器/解码器堆叠以及输出层等 ...
        self._init_parameters()

    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def generate_mask(self, src, tgt):
        # src_mask: [batch, 1, 1, src_len],pad token 假设为 0
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        
        # tgt_mask: [batch, 1, tgt_len, tgt_len],结合 pad mask 和 causal mask
        tgt_len = tgt.size(1)
        tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, tgt_len]
        tgt_subsequent_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
        tgt_mask = tgt_pad_mask & tgt_subsequent_mask.unsqueeze(0)
        return src_mask, tgt_mask

    def encode(self, src, src_mask):
        x = self.src_embedding(src) * math.sqrt(self.dim)
        x = self.pos_encoder(x)
        x = self.dropout(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        x = self.tgt_embedding(tgt) * math.sqrt(self.dim)
        x = self.pos_encoder(x)
        x = self.dropout(x)
        for layer in self.decoder_layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return x

    # 前向传播
    def forward(self, src, tgt):

主程序代码

import torch
from src.transformer import Transformer

def main():
    # 超参数
    src_vocab_size = 100
    tgt_vocab_size = 100
    dim = 512
    n_heads = 8
    n_layers = 6
    hidden_dim = 2048
    max_seq_len = 50
    dropout = 0.1
    
    # 实例化模型
    model = Transformer(
        src_vocab_size, 
        tgt_vocab_size, 
        dim, 
        n_heads, 
        n_layers, 
        hidden_dim, 
        max_seq_len, 
        dropout
    )
    
    # 模拟输入数据
    batch_size = 2
    src_len = 10
    tgt_len = 12
    
    # 随机生成 src 和 tgt 序列 (假设 pad_token_id=0)
    # 确保没有 pad token 影响简单测试,或者手动插入
    src = torch.randint(1, src_vocab_size, (batch_size, src_len))
    tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))
    
    # 前向传播
    output = model(src, tgt)
    
    print("Model Architecture:")
    # print(model)
    print("\nTest Input:")
    print(f"Source Shape: {src.shape}")
    print(f"Target Shape: {tgt.shape}")
    
    print("\nModel Output:")
    print(f"Output Shape: {output.shape}") # 预期 [batch_size, tgt_len, tgt_vocab_size]

if __name__ == "__main__":
    main()

参考资料

https://github.com/datawhalechina/base-nlp/

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。