【Datawhale学习笔记】注意力机制及Transform代码实践
注意力机制
设计原理
在解码器生成每一个词元时,不再依赖一个固定的上下文向量,而是允许它“回头看”一遍完整的输入序列,并根据当前解码的需求,自主地为输入序列的每个部分分配不同的注意力权重,然后基于这些权重将输入信息加权求和,生成一个动态的、专属当前时间步的上下文向量。
通俗地理解为从“一言以蔽之”到“择其要者而观之”的转变
注意力机制详解
三部曲
- 计算相似度
使用解码器上一时刻的隐藏状态 与编码器的每一个隐藏状态 计算一个分数,这个分数衡量了在当前解码时刻,应当对第 个输入词元投入多少“关注”。
这个分数越高,代表关联性越强。计算这个分数的方式有很多种,例如简单的点积、或者引入一个可学习的神经网络层。
- 计算注意力权重
得到输入序列所有位置的注意力分数 后,为了将它们转换成一种“权重”的表示,可使用 Softmax 函数对其进行归一化。这样,就能得到一组总和为 1、且均为正数的注意力权重 。
这组权重 构成了一个概率分布,清晰地表明了在当前解码步骤 ,注意力应该如何分配在输入序列的各个位置上。
- 加权求和,生成上下文向量
最后,使用上一步得到的注意力权重 ,对编码器的所有隐藏状态 进行加权求和,从而得到当前解码时刻 专属的上下文向量 。
这个 向量,由于是根据当前解码需求动态生成的,它比原始 Seq2Seq 的那个固定向量 包含了更具针对性的信息。
PyTorch 实现与代码解析
编码器
class Encoder(nn.Module):
def __init__(self, vocab_size, hidden_size, num_layers):
super(Encoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size
)
self.rnn = nn.LSTM(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
bidirectional=True # 使用双向LSTM
)
self.fc = nn.Linear(hidden_size * 2, hidden_size)
def forward(self, x):
embedded = self.embedding(x)
outputs, (hidden, cell) = self.rnn(embedded)
# 将双向RNN的输出通过线性层降维,使其与解码器维度匹配
outputs = torch.tanh(self.fc(outputs))
return outputs, hidden, cell
- bidirectional=True:启用双向 LSTM,使原始 RNN outputs 维度变为 (batch, src_len, hidden_size * 2)。
- self.fc:定义一个线性层,将拼接后的双向输出映射回 hidden_size 维度;经过 self.fc 和 tanh 后,outputs 维度回到 (batch, src_len, hidden_size),方便后续计算。
- return outputs, …:返回降维后的所有时间步输出 outputs (作为后续的 K 和 V),以及原始的最终状态 hidden 和 cell。
注意力模块的两种实现
无参数的注意力
class AttentionSimple(nn.Module):
"""1: 无参数的注意力模块"""
def __init__(self, hidden_size):
super(AttentionSimple, self).__init__()
# 确保缩放因子是一个 non-learnable buffer
self.register_buffer("scale_factor", torch.sqrt(torch.FloatTensor([hidden_size])))
def forward(self, hidden, encoder_outputs):
# hidden shape: (num_layers, batch_size, hidden_size)
# encoder_outputs shape: (batch_size, src_len, hidden_size)
# Q: 解码器最后一层的隐藏状态
query = hidden[-1].unsqueeze(1) # -> (batch, 1, hidden)
# K/V: 编码器的所有输出
keys = encoder_outputs # -> (batch, src_len, hidden)
# energy shape: (batch, 1, src_len)
energy = torch.bmm(query, keys.transpose(1, 2)) / self.scale_factor
# attention_weights shape: (batch, src_len)
return torch.softmax(energy, dim=2).squeeze(1)
带参数的注意力
class AttentionParams(nn.Module):
"""2: 带参数的注意力模块"""
def __init__(self, hidden_size):
super(AttentionParams, self).__init__()
self.attn = nn.Linear(hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.rand(hidden_size))
def forward(self, hidden, encoder_outputs):
src_len = encoder_outputs.shape[1]
hidden_last_layer = hidden[-1].unsqueeze(1).repeat(1, src_len, 1)
energy = torch.tanh(self.attn(torch.cat((hidden_last_layer, encoder_outputs), dim=2)))
attention = torch.sum(self.v * energy, dim=2)
return torch.softmax(attention, dim=1)
通用解码器
class DecoderWithAttention(nn.Module):
def __init__(self, vocab_size, hidden_size, num_layers, attention_module):
super(DecoderWithAttention, self).__init__()
self.attention = attention_module
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size
)
self.rnn = nn.LSTM(
input_size=hidden_size * 2, # 输入维度是 词嵌入(hidden_size) + 上下文向量(hidden_size)
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True
)
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden, cell, encoder_outputs):
embedded = self.embedding(x.unsqueeze(1))
# 1. 计算注意力权重
# a shape: [batch, src_len]
a = self.attention(hidden, encoder_outputs).unsqueeze(1)
# 2. 计算上下文向量
context = torch.bmm(a, encoder_outputs)
# 3. 将上下文向量与当前输入拼接
rnn_input = torch.cat((embedded, context), dim=2)
# 4. 传入RNN解码
outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
# 5. 预测输出
predictions = self.fc(outputs.squeeze(1))
return predictions, hidden, cell
Seq2Seq 包装模块
class Seq2Seq(nn.Module):
"""带注意力的Seq2Seq"""
def __init__(self, encoder, decoder, device):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.device = device
def forward(self, src, trg, teacher_forcing_ratio=0.5):
batch_size = src.shape[0]
trg_len = trg.shape[1]
trg_vocab_size = self.decoder.fc.out_features
outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
encoder_outputs, hidden, cell = self.encoder(src)
# 适配Encoder(双向)和Decoder(单向)的状态维度
hidden = hidden.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)
cell = cell.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)
input = trg[:, 0]
for t in range(1, trg_len):
# 在循环的每一步,都将 encoder_outputs 传递给解码器
# 这是 Attention 机制能够"回顾"整个输入序列的关键
output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs)
outputs[:, t, :] = output
teacher_force = random.random() < teacher_forcing_ratio
top1 = output.argmax(1)
input = trg[:, t] if teacher_force else top1
return outputs
完整代码
import torch
import torch.nn as nn
import random
torch.manual_seed(42)
# 1. 全局配置
batch_size = 8
src_len = 10
trg_len = 12
src_vocab_size = 100
trg_vocab_size = 120
hidden_size = 64
num_layers = 2
sos_idx = 1
eos_idx = 2
# 2. 模型定义
class Encoder(nn.Module):
"""编码器: 读取输入序列,输出所有时间步的特征以及最终状态。"""
def __init__(self, vocab_size, hidden_size, num_layers):
super(Encoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size
)
self.rnn = nn.LSTM(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
bidirectional=True # 使用双向LSTM
)
self.fc = nn.Linear(hidden_size * 2, hidden_size)
def forward(self, x):
embedded = self.embedding(x)
outputs, (hidden, cell) = self.rnn(embedded)
outputs = torch.tanh(self.fc(outputs))
return outputs, hidden, cell
class DecoderWithAttention(nn.Module):
"""带注意力的通用解码器"""
def __init__(self, vocab_size, hidden_size, num_layers, attention_module):
super(DecoderWithAttention, self).__init__()
self.attention = attention_module
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size
)
self.rnn = nn.LSTM(
input_size=hidden_size * 2, # 输入维度是 词嵌入 + 上下文向量
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True
)
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden, cell, encoder_outputs):
x = x.unsqueeze(1)
embedded = self.embedding(x)
a = self.attention(hidden, encoder_outputs).unsqueeze(1)
context = torch.bmm(a, encoder_outputs)
rnn_input = torch.cat((embedded, context), dim=2)
outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
predictions = self.fc(outputs.squeeze(1))
return predictions, hidden, cell
class Seq2Seq(nn.Module):
"""带注意力的Seq2Seq"""
def __init__(self, encoder, decoder, device):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.device = device
def forward(self, src, trg, teacher_forcing_ratio=0.5):
batch_size = src.shape[0]
trg_len = trg.shape[1]
trg_vocab_size = self.decoder.fc.out_features
outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
encoder_outputs, hidden, cell = self.encoder(src)
hidden = hidden.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)
cell = cell.view(self.encoder.rnn.num_layers, 2, batch_size, -1).sum(dim=1)
input = trg[:, 0]
for t in range(1, trg_len):
output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs)
outputs[:, t, :] = output
teacher_force = random.random() < teacher_forcing_ratio
top1 = output.argmax(1)
input = trg[:, t] if teacher_force else top1
return outputs
def greedy_decode(self, src, max_len=trg_len):
self.eval()
with torch.no_grad():
encoder_outputs, hidden, cell = self.encoder(src)
hidden = hidden.view(self.encoder.rnn.num_layers, 2, src.shape[0], -1).sum(axis=1)
cell = cell.view(self.encoder.rnn.num_layers, 2, src.shape[0], -1).sum(axis=1)
trg_indexes = [sos_idx]
for _ in range(max_len):
input_tensor = torch.LongTensor([trg_indexes[-1]]).to(self.device)
output, hidden, cell = self.decoder(input_tensor, hidden, cell, encoder_outputs)
pred_token = output.argmax(1).item()
trg_indexes.append(pred_token)
if pred_token == eos_idx:
break
return trg_indexes
# 3. Attention 定义
class AttentionSimple(nn.Module):
"""1: 无参数的注意力模块"""
def __init__(self, hidden_size):
super(AttentionSimple, self).__init__()
# 确保缩放因子是一个 non-learnable buffer
self.register_buffer("scale_factor", torch.sqrt(torch.FloatTensor([hidden_size])))
def forward(self, hidden, encoder_outputs):
# hidden shape: (num_layers, batch_size, hidden_size)
# encoder_outputs shape: (batch_size, src_len, hidden_size)
# Q: 解码器最后一层的隐藏状态
query = hidden[-1].unsqueeze(1) # -> (batch, 1, hidden)
# K/V: 编码器的所有输出
keys = encoder_outputs # -> (batch, src_len, hidden)
# energy shape: (batch, 1, src_len)
energy = torch.bmm(query, keys.transpose(1, 2)) / self.scale_factor
# attention_weights shape: (batch, src_len)
return torch.softmax(energy, dim=2).squeeze(1)
class AttentionParams(nn.Module):
"""2: 带参数的注意力模块"""
def __init__(self, hidden_size):
super(AttentionParams, self).__init__()
self.attn = nn.Linear(hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.rand(hidden_size))
def forward(self, hidden, encoder_outputs):
src_len = encoder_outputs.shape[1]
hidden_last_layer = hidden[-1].unsqueeze(1).repeat(1, src_len, 1)
energy = torch.tanh(self.attn(torch.cat((hidden_last_layer, encoder_outputs), dim=2)))
attention = torch.sum(self.v * energy, dim=2)
return torch.softmax(attention, dim=1)
# 4. 主流程
if __name__ == '__main__':
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- 统一创建编码器和伪数据 ---
encoder = Encoder(src_vocab_size, hidden_size, num_layers).to(device)
src = torch.randint(1, src_vocab_size, (batch_size, src_len)).to(device)
trg = torch.randint(1, trg_vocab_size, (batch_size, trg_len)).to(device)
# =========================================
# 1: 无参数的基础 Attention
# =========================================
print("\n" + "="*20 + " 1: 无参数的基础 Attention " + "="*20)
attention_simple = AttentionSimple(hidden_size).to(device)
decoder_simple = DecoderWithAttention(trg_vocab_size, hidden_size, num_layers, attention_simple).to(device)
model_simple = Seq2Seq(encoder, decoder_simple, device).to(device)
model_simple.train()
outputs_simple = model_simple(src, trg)
print(f"模型结构:\n{model_simple}")
print(f"\n训练模式输出张量形状: {outputs_simple.shape}")
prediction_simple = model_simple.greedy_decode(src[0:1, :])
print(f"推理的预测结果: {prediction_simple}")
# =========================================
# 2: 带参数的 Attention
# =========================================
print("\n" + "="*20 + " 2: 带参数的 Attention " + "="*20)
attention_params = AttentionParams(hidden_size).to(device)
decoder_params = DecoderWithAttention(trg_vocab_size, hidden_size, num_layers, attention_params).to(device)
model_params = Seq2Seq(encoder, decoder_params, device).to(device)
model_params.train()
outputs_params = model_params(src, trg)
print(f"模型结构:\n{model_params}")
print(f"\n训练模式输出张量形状: {outputs_params.shape}")
prediction_params = model_params.greedy_decode(src[0:1, :])
print(f"推理的预测结果: {prediction_params}")
注意力机制的类型
Soft Attention vs. Hard Attention
-
Soft Attention:这就是前文一直在详细讨论的机制。它为输入序列的所有位置都计算一个注意力权重,这些权重是 0 到 1 之间的浮点数(经 Softmax 归一化),然后进行加权求和。这种方式的优点是模型是端到端可微的,可以使用标准的梯度下降法进行训练。其缺点是在处理非常长的序列时,计算开销会很大。因为解码的每一步,都需要计算当前状态与所有输入状态的相似度。
-
Hard Attention3:与 Soft Attention 对所有输入进行加权不同,Hard Attention 在每一步只选择一个最相关的输入位置。可以看作是一种“非 0 即 1”的注意力分配,即选中的位置权重为 1,其他所有位置的权重均为 0。这样做的好处是计算量大大减少,因为不再需要进行全面的加权求和。但它的缺点也很突出:选择过程是离散的、不可微的,因此无法使用常规的反向传播算法进行训练,通常需要借助强化学习等更复杂的技巧。
Global Attention vs. Local Attention
-
Global Attention (全局注意力):其思想与 Soft Attention 基本一致,即在计算注意力时,会考虑编码器的所有隐藏状态。
-
Local Attention (局部注意力):这是一种介于 Soft Attention 和 Hard Attention 之间的折中方案。能够减少计算量,但又不像 Hard Attention 那样极端。其核心思想是,在每个解码时间步,只关注输入序列的一个局部窗口。它的工作流程通常是:
-
预测对齐位置:首先,模型需要预测一个当前解码步最关注的源序列位置 。这个位置可以通过一个小型神经网络,仅依赖于当前解码器状态 来预测,从而避免了与所有编码器状态进行比较,降低了计算成本。预测公式可以设计为: ,其中 是源序列长度, 和 是可学习的参数。
-
定义窗口:以预测出的 为中心,定义一个大小为 的窗口,其中 是一个超参数。
-
局部计算:最后,模型只在这个窗口内的编码器状态上应用 Soft Attention 机制,计算权重并生成上下文向量。
Transformer
Transformer的提出
2017年,Google 的研究团队发表了一篇名为《Attention Is All You Need》的论文,提出了一种全新的架构——Transformer
自注意力机制
从根本上说,要让模型理解一段文本,就需要提取其“序列特征”,即将文本中所有词元的信息以某种方式整合起来。RNN 通过依次传递隐藏状态来顺序地整合信息,而 Transformer 则选择了一条截然不同的道路。其核心是 自注意力机制。它不再依赖于顺序计算,而是将提取序列特征的过程看作是输入序列“自己对自己进行注意力计算”。序列中的每个词元都会“审视”序列中的所有其他词元,来动态地计算出最能代表当前词元上下文含义的新表示。与上一节介绍的交叉注意力不同,在自注意力中,Query、Key、Value 均来源于同一个输入序列。
PyTorch 实现自注意力
class SelfAttention(nn.Module):
"""自注意力模块"""
def __init__(self, hidden_size):
super(SelfAttention, self).__init__()
self.hidden_size = hidden_size
self.q_linear = nn.Linear(hidden_size, hidden_size)
self.k_linear = nn.Linear(hidden_size, hidden_size)
self.v_linear = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
q = self.q_linear(x)
k = self.k_linear(x)
v = self.v_linear(x)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.hidden_size)
attention_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attention_weights, v)
return context
- init: 初始化了三个 nn.Linear 层,它们分别对应将输入映射到 Q, K, V 空间的权重矩阵 。
forward:- q_linear(x), k_linear(x), v_linear(x):将形状为 [batch_size, seq_len, hidden_size] 的输入张量 x 分别通过三个线性层,一次性地为序列中的所有词元计算出 Q, K, V 矩阵。
- torch.matmul(q, k.transpose(-2, -1)): 这是实现并行计算的核心。通过将 K 矩阵的最后两个维度转置(seq_len, hidden_size -> hidden_size, seq_len),再与 Q 矩阵相乘,直接得到了一个 [batch_size, seq_len, seq_len] 的分数矩阵。该矩阵中的 scores[b, i, j] 代表了批次 b 中第 i 个词元对第 j 个词元的注意力分数。
- / math.sqrt(self.hidden_size):执行缩放操作,防止梯度消失。
- torch.softmax(scores, dim=-1):对分数的最后一个维度(seq_len)进行 Softmax,得到归一化的注意力权重。
- torch.matmul(attention_weights, v):将权重矩阵与 V 矩阵相乘,完成了对所有词元的 Value 向量的加权求和,得到最终的上下文感知表示。
多头注意力机制
仅仅用一组 矩阵进行一次自注意力计算,相当于只从一个“视角”来审视文本内在的关系。然而,文本中的关系是多层次的,例如,一组参数可能学会了关注代词(如 “它” 指向谁)的关系,但可能忽略了动作的执行者(主谓宾)等其他类型的关系。
为了让模型能够综合利用从不同维度和视角提取出的信息,Transformer 引入了多头注意力机制 (Multi-Head Attention)。其思想非常直接:并行地执行多次自注意力计算,每一次计算都是一个独立的“头 (Head)”。每个头都拥有一组自己专属的 权重矩阵,并且可以学习去关注一种特定类型的上下文关系。
PyTorch 实现多头注意力
class MultiHeadSelfAttention(nn.Module):
"""多头自注意力模块"""
def __init__(self, hidden_size, num_heads):
super(MultiHeadSelfAttention, self).__init__()
assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.q_linear = nn.Linear(hidden_size, hidden_size)
self.k_linear = nn.Linear(hidden_size, hidden_size)
self.v_linear = nn.Linear(hidden_size, hidden_size)
self.wo = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
batch_size, seq_len, _ = x.shape
q = self.q_linear(x)
k = self.k_linear(x)
v = self.v_linear(x)
# 拆分多头
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 并行计算注意力
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attention_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attention_weights, v)
# 合并多头结果
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
# 输出层
output = self.wo(context)
return output
Transformer 整体结构
编码器(Encoder)
编码器的作用是“理解”和“消化”输入的整个序列,为序列中的每个词元生成一个富含上下文信息的表示。一个标准的编码器层由两个主要的子层构成,分别是多头自注意力层(Multi-Head Self-Attention Layer)和位置前馈网络(Position-wise Feed-Forward Network)。每个子层的输出都经过了**残差连接(Add)与层归一化(Norm)**处理。所以,一个编码器层内部的数据流可以表示为 x -> Sublayer1(x) -> Add & Norm -> Sublayer2(…) -> Add & Norm。
关键特性:
- 注意力类型:编码器中的多头注意力层是双向的自注意力。这意味着在计算时,序列中的任何一个词元都可以“看到”序列中的所有其他词元(包括它自己、它前面的和它后面的)。
- 功能:由于其双向性,编码器非常擅长理解完整的输入文本,并为每个词元生成一个深度融合了上下文信息的表示。
- 应用:通过大量堆叠编码器层而构建的模型(Encoder-Only 架构),如 BERT,在文本分类、命名实体识别等自然语言理解(NLU)任务上取得了巨大成功。
解码器(Decoder)
解码器的作用是基于编码器对原始输入的理解,并结合已经生成的部分,来逐个生成下一个词元。为了完成这个更复杂的任务,一个标准的解码器层(Decoder Layer)比编码器层多了一个注意力子层,总共包含三个子层。分别是带掩码的多头自注意力层(Masked Multi-Head Self-Attention Layer)、交叉注意力层(Cross-Attention Layer)和位置前馈网络(Position-wise Feed-Forward Network)。同样,解码器的每个子层也都采用了残差连接和层归一化。
Transformer 代码实践
搭建整体框架
- Embedding 层:将输入的 token ID 转换为连续的向量表示,并加上位置编码以保留序列顺序信息。
- Encoder 堆叠:由 个 EncoderLayer 串联而成,负责深度提取和理解输入序列的特征。
- Decoder 堆叠:由 个 DecoderLayer 串联而成,负责基于 Encoder 的输出逐步生成目标序列。
- Output 层:一个线性层,将解码器的最终输出映射回词表大小,用于计算下一个词的概率分布。
# src/transformer.py
import torch.nn as nn
from .pos import PositionalEncoding # 稍后实现
# ... 导入其他组件
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, dim=512, n_heads=8, n_layers=6, ...):
super().__init__()
self.dim = dim
# 1. 嵌入层与位置编码
# src_embedding: 将源语言序列映射为向量 (Encoder输入)
self.src_embedding = nn.Embedding(src_vocab_size, dim)
# tgt_embedding: 将目标语言序列映射为向量 (Decoder输入)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, dim)
self.pos_encoder = PositionalEncoding(dim, max_seq_len)
self.dropout = nn.Dropout(dropout)
# 2. 编码器与解码器堆叠
# 使用 ModuleList 来存储层列表,支持按索引访问和自动注册参数
self.encoder_layers = nn.ModuleList([
EncoderLayer(dim, n_heads, hidden_dim, dropout) for _ in range(n_layers)
])
self.decoder_layers = nn.ModuleList([
DecoderLayer(dim, n_heads, hidden_dim, dropout) for _ in range(n_layers)
])
# 3. 输出头
self.output = nn.Linear(dim, tgt_vocab_size)
def forward(self, src, tgt):
# 1. 生成掩码 (Padding Mask & Causal Mask)
src_mask, tgt_mask = self.generate_mask(src, tgt)
# 2. 编码器前向传播
enc_output = self.encode(src, src_mask)
# 3. 解码器前向传播
dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
# 4. 输出 Logits
return self.output(dec_output)
return logits
实现核心组件
位置编码
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
"""
正弦位置编码
Transformer 论文中使用固定公式计算位置编码,不涉及可学习参数。
"""
def __init__(self, dim, max_seq_len=5000):
super().__init__()
# 创建一个足够长的 PE 矩阵 [max_seq_len, dim]
pe = torch.zeros(max_seq_len, dim)
# 生成位置索引 [0, 1, ..., max_seq_len-1] -> [max_seq_len, 1]
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
# 计算分母中的 div_term: 10000^(2i/dim) = exp(2i * -log(10000)/dim)
# 这种对数变换的计算方式在数值上更稳定
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
# 填充 PE 矩阵
# 偶数维度用 sin,奇数维度用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 增加 batch 维度: [1, max_seq_len, dim] 以便广播
pe = pe.unsqueeze(0)
# 注册为 buffer
# register_buffer 的作用是告诉 PyTorch:
# 1. 'pe' 是模型状态的一部分,会随模型保存和加载 (state_dict)。
# 2. 'pe' 不是模型参数 (Parameter),优化器更新时不会更新它。
self.register_buffer('pe', pe)
在前向传播中,我们的任务就是将位置编码加到输入的词嵌入上。由于我们预先生成的 pe 矩阵可能比当前的输入序列 x 要长,所以需要根据 x 的实际长度对 pe 进行切片。
class PositionalEncoding(nn.Module):
def __init__(self, dim, max_seq_len=5000):
...
def forward(self, x):
"""
Args:
x: 输入的词嵌入序列 [batch_size, seq_len, dim]
Returns:
加上位置编码后的序列 [batch_size, seq_len, dim]
"""
# 截取与输入序列长度对应的位置编码并相加
# x.size(1) 是 seq_len
# self.pe 的形状是 [1, max_seq_len, dim],切片后会自动广播到 batch_size
x = x + self.pe[:, :x.size(1), :]
return x
最后,我们可以编写一段简单的测试代码来验证维度是否正确。
if __name__ == "__main__":
# 准备参数
batch_size, seq_len, dim = 2, 10, 512
max_seq_len = 100
# 初始化模块
pe = PositionalEncoding(dim, max_seq_len)
# 准备输入
x = torch.zeros(batch_size, seq_len, dim) # 输入为0,直接观察PE值
# 前向传播
output = pe(x)
# 验证输出
print("--- PositionalEncoding Test ---")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
多头注意力
# src/attention.py
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, dim, n_heads, dropout=0.1):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
# 定义 Wq, Wk, Wv 矩阵
# 这里我们使用一个大的线性层一次性计算所有头的 Q, K, V
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wv = nn.Linear(dim, dim)
# 最终输出的线性层 Wo
self.wo = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
这部分前向传播的重点是“分头”操作。我们不直接对 [batch, seq_len, dim] 进行计算,而是将其 reshape 为 [batch, n_heads, seq_len, head_dim],这样就可以利用矩阵运算并行地处理所有头。
class MultiHeadAttention(nn.Module):
def __init__(self, dim, n_heads, dropout=0.1):
...
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 1. 线性投影
# [batch, seq_len, dim] -> [batch, seq_len, dim]
q = self.wq(q)
k = self.wk(k)
v = self.wv(v)
# 2. 分头 (Split Heads)
# 变换形状: [batch, seq_len, n_heads, head_dim]
# 然后转置: [batch, n_heads, seq_len, head_dim] 以便并行计算
q = q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
# 3. 计算缩放点积注意力 (Scaled Dot-Product Attention)
# scores: [batch, n_heads, seq_len, seq_len]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 4. 应用掩码 (Masking)
if mask is not None:
# mask == 0 的位置被填充为负无穷,Softmax 后变为 0
scores = scores.masked_fill(mask == 0, float('-inf'))
# 5. Softmax 与加权求和
attn_weights = torch.softmax(scores, dim=-1)
if self.dropout is not None:
attn_weights = self.dropout(attn_weights)
# context: [batch, n_heads, seq_len, head_dim]
context = torch.matmul(attn_weights, v)
# 6. 合并多头 (Concat Heads)
# [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, n_heads, head_dim]
# -> [batch, seq_len, dim]
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)
# 7. 输出层投影
output = self.wo(context)
return output
# 单元测试
if __name__ == "__main__":
# 准备参数
batch_size, seq_len, dim = 2, 10, 512
n_heads = 8
# 初始化模块
mha = MultiHeadAttention(dim, n_heads)
# 准备输入 (Query, Key, Value 相同)
x = torch.randn(batch_size, seq_len, dim)
# 前向传播
output = mha(x, x, x)
# 验证输出
print("--- MultiHeadAttention Test ---")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
前馈神经网络
# src/ffn.py
import torch.nn as nn
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.1):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim) # 升维
self.w2 = nn.Linear(hidden_dim, dim) # 降维
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 线性变换 -> ReLU -> Dropout -> 线性变换
return self.w2(self.dropout(torch.relu(self.w1(x))))
if __name__ == "__main__":
# 准备参数
batch_size, seq_len, dim = 2, 10, 512
hidden_dim = 2048
# 初始化模块
ffn = FeedForward(dim, hidden_dim)
# 准备输入
x = torch.randn(batch_size, seq_len, dim)
# 前向传播
output = ffn(x)
# 验证输出
print("--- FeedForward Test ---")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
层归一化
import torch
import torch.nn as nn
class LayerNorm(nn.Module):
"""
层归一化 (Layer Normalization)
公式: y = (x - mean) / sqrt(var + eps) * gamma + beta
"""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
# 可学习参数 gamma (缩放) 和 beta (偏移)
# nn.Parameter 会被自动注册为模型参数
self.gamma = nn.Parameter(torch.ones(dim))
self.beta = nn.Parameter(torch.zeros(dim))
def forward(self, x):
# x: [batch_size, seq_len, dim]
# 在最后一个维度 (dim) 上计算均值和方差
# keepdim=True 保持维度以便进行广播计算
mean = x.mean(-1, keepdim=True)
# unbiased=False 使用有偏估计 (分母为 N),与 PyTorch 默认行为一致
var = x.var(-1, keepdim=True, unbiased=False)
# 归一化
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# 缩放和平移
return self.gamma * x_norm + self.beta
# 单元测试
if __name__ == "__main__":
# 准备参数
batch_size, seq_len, dim = 2, 10, 512
# 初始化模块
ln = LayerNorm(dim)
# 准备输入
x = torch.randn(batch_size, seq_len, dim)
# 前向传播
output = ln(x)
# 验证输出
print("--- LayerNorm Test ---")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
组装与运行
核心框架
import torch
import torch.nn as nn
import math
# 导入组件
from .attention import MultiHeadAttention
from .ffn import FeedForward
from .norm import LayerNorm
from .pos import PositionalEncoding
class EncoderLayer(nn.Module):
def __init__(self, dim, n_heads, hidden_dim, dropout=0.1):
super().__init__()
# 多头自注意力子层
self.attention = MultiHeadAttention(dim, n_heads, dropout)
self.attention_norm = LayerNorm(dim)
# 前馈网络子层
self.feed_forward = FeedForward(dim, hidden_dim, dropout)
self.ffn_norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 子层 1:自注意力
_x = x
x = self.attention(x, x, x, mask) # Q=K=V=x
x = self.attention_norm(_x + self.dropout(x))
# 子层 2:前馈网络
_x = x
x = self.feed_forward(x)
x = self.ffn_norm(_x + self.dropout(x))
return x
接下来是解码器层,这部分比编码器层多了一个“交叉注意力”子层,先是带掩码的自注意力,再是对编码器输出的交叉注意力,最后是前馈网络。
class DecoderLayer(nn.Module):
def __init__(self, dim, n_heads, hidden_dim, dropout=0.1):
super().__init__()
# 1. 带掩码的自注意力
self.self_attention = MultiHeadAttention(dim, n_heads, dropout)
self.self_attn_norm = LayerNorm(dim)
# 2. 交叉注意力
self.cross_attention = MultiHeadAttention(dim, n_heads, dropout)
self.cross_attn_norm = LayerNorm(dim)
# 3. 前馈网络
self.feed_forward = FeedForward(dim, hidden_dim, dropout)
self.ffn_norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask, tgt_mask):
# 子层 1:带掩码的自注意力
_x = x
x = self.self_attention(x, x, x, tgt_mask)
x = self.self_attn_norm(_x + self.dropout(x))
# 子层 2:交叉注意力(Q 来自解码器,K/V 来自编码器输出)
_x = x
x = self.cross_attention(x, enc_output, enc_output, src_mask)
x = self.cross_attn_norm(_x + self.dropout(x))
# 子层 3:前馈网络
_x = x
x = self.feed_forward(x)
x = self.ffn_norm(_x + self.dropout(x))
return x
最后在 Transformer 主类中,我们需要补全相关的辅助方法。
class Transformer(nn.Module):
def __init__(self,
src_vocab_size,
tgt_vocab_size,
dim=512,
n_heads=8,
n_layers=6,
hidden_dim=2048,
max_seq_len=5000,
dropout=0.1):
# ... 初始化嵌入层、位置编码、编码器/解码器堆叠以及输出层等 ...
self._init_parameters()
def _init_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def generate_mask(self, src, tgt):
# src_mask: [batch, 1, 1, src_len],pad token 假设为 0
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
# tgt_mask: [batch, 1, tgt_len, tgt_len],结合 pad mask 和 causal mask
tgt_len = tgt.size(1)
tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2) # [batch, 1, 1, tgt_len]
tgt_subsequent_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
tgt_mask = tgt_pad_mask & tgt_subsequent_mask.unsqueeze(0)
return src_mask, tgt_mask
def encode(self, src, src_mask):
x = self.src_embedding(src) * math.sqrt(self.dim)
x = self.pos_encoder(x)
x = self.dropout(x)
for layer in self.encoder_layers:
x = layer(x, src_mask)
return x
def decode(self, tgt, enc_output, src_mask, tgt_mask):
x = self.tgt_embedding(tgt) * math.sqrt(self.dim)
x = self.pos_encoder(x)
x = self.dropout(x)
for layer in self.decoder_layers:
x = layer(x, enc_output, src_mask, tgt_mask)
return x
# 前向传播
def forward(self, src, tgt):
主程序代码
import torch
from src.transformer import Transformer
def main():
# 超参数
src_vocab_size = 100
tgt_vocab_size = 100
dim = 512
n_heads = 8
n_layers = 6
hidden_dim = 2048
max_seq_len = 50
dropout = 0.1
# 实例化模型
model = Transformer(
src_vocab_size,
tgt_vocab_size,
dim,
n_heads,
n_layers,
hidden_dim,
max_seq_len,
dropout
)
# 模拟输入数据
batch_size = 2
src_len = 10
tgt_len = 12
# 随机生成 src 和 tgt 序列 (假设 pad_token_id=0)
# 确保没有 pad token 影响简单测试,或者手动插入
src = torch.randint(1, src_vocab_size, (batch_size, src_len))
tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))
# 前向传播
output = model(src, tgt)
print("Model Architecture:")
# print(model)
print("\nTest Input:")
print(f"Source Shape: {src.shape}")
print(f"Target Shape: {tgt.shape}")
print("\nModel Output:")
print(f"Output Shape: {output.shape}") # 预期 [batch_size, tgt_len, tgt_vocab_size]
if __name__ == "__main__":
main()
参考资料
- 点赞
- 收藏
- 关注作者
评论(0)