图解 Self-Attention(自注意力机制)简单原理与实现
图解 Self-Attention(自注意力机制)简单原理与实现
1. 为什么需要 Self-Attention?
在深度学习处理序列数据(如自然语言)时,我们最初依赖 RNN(循环神经网络)。但 RNN 有两个致命缺点:
- 长距离依赖问题:信息在序列中传递时会逐渐丢失,难以联系相隔很远的词。
- 计算效率低:必须按顺序一个词一个词计算,无法并行化计算。
Self-Attention(自注意力机制) 的出现打破了这种僵局。它允许模型在处理序列中的每一个元素时,都能“同时”观察到序列中的所有其他元素,并根据相关性分配注意力权重。
2. 核心直觉:类比“图书馆检索”(仅供参考)
为了理解 Self-Attention 的工作原理,可以引入三个核心概念:Query (Q), Key (K), 和 Value (V)。
想象你在一个图书馆里找书:
- Query (Q):你手里拿着的“查询搜索词”(我想找关于“老虎”的书籍)。
- Key (K):书架上每本书的“标签/索引”(书 A 标签是“动物”,书 B 标签是“植物”)。
- Value (V):书里的“具体内容”。
Self-Attention 的过程就是:
- 用你的 Query 去和所有的 Key 做匹配,计算相似度(注意力分数)。
- 经过归一化(Softmax),得到一组权重。
- 根据权重对所有的 Value 进行加权求和,得到最终的输出。
3. 数学原理:五步走
假设输入序列的嵌入矩阵为 ,Self-Attention 的计算步骤如下:
第一步:线性变换
通过三个可学习的权重矩阵 ,将输入 映射到三个空间:
第二步:计算注意力分数(Attention Score)
计算 与 的点积,衡量序列中词与词之间的相关性:
第三步:缩放(Scaling)
为了防止点积结果过大导致 Softmax 梯度消失,除以 ( 是向量维度):
第四步:归一化(Softmax)
通过 Softmax 将分数转化为概率分布(总和为 1):
第五步:加权求和
用得到的权重 对 进行加权,提取出重要信息:

4. 总结公式
将上述步骤合并,就是大名鼎鼎的 Attention 公式:
5. Multi-Head Attention:多头注意力
为什么要用“多头”?
如果只有一组 ,模型只能学到一种关注方式(比如只关注语法结构)。Multi-Head Attention 相当于让多个人从不同的角度看同一个句子:
- 第一头:关注词与词的语义关系。
- 第二头:关注指代关系(如 “it” 指代 “apple”)。
- 第三头:关注时间/地点。
最后将所有头的输出拼接(Concat)并线性变换即可。
6. PyTorch 极简实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query):
N = query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# Split embedding into self.heads pieces
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# Einsum does matrix multiplication for query*keys
# queries shape: (N, query_len, heads, head_dim)
# keys shape: (N, key_len, heads, head_dim)
# energy shape: (N, heads, query_len, key_len)
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# Softmax and Weighted Sum
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhqk,nvhd->nqhd", [attention, values]).reshape(
N, query_len, self.heads * self.head_dim
)
return self.fc_out(out)
7. 结语
Self-Attention 的精妙之处在于它彻底改变了模型处理序列的方式。它不再通过“记忆”来传递信息,而是通过“观察”来全局捕捉特征。这也是后来 BERT, GPT, Claude 等大模型能够如此强大的底层基石。
如果你对 Transformer 架构感兴趣,理解 Self-Attention 就是通往现代 NLP 大门的第一把钥匙。
感谢阅读!如果你觉得有帮助,欢迎点赞分享。
- 点赞
- 收藏
- 关注作者
评论(0)