图解 Self-Attention(自注意力机制)简单原理与实现

举报
pluto1447 发表于 2025/12/24 21:02:01 2025/12/24
【摘要】 想快速初步理解self-attention?本文带你图解 Self-Attention。通过通俗的类比和清晰的流程图,拆解查询(Q)、键(K)和值(V)之间的互动秘密。我们将探讨它如何通过权重分配实现全局视野,并解决传统模型难以并行计算的痛点。无论你是初学者还是想查漏补缺,这篇“保姆级”指南都不容错过。

图解 Self-Attention(自注意力机制)简单原理与实现

1. 为什么需要 Self-Attention?

在深度学习处理序列数据(如自然语言)时,我们最初依赖 RNN(循环神经网络)。但 RNN 有两个致命缺点:

  1. 长距离依赖问题:信息在序列中传递时会逐渐丢失,难以联系相隔很远的词。
  2. 计算效率低:必须按顺序一个词一个词计算,无法并行化计算。

Self-Attention(自注意力机制) 的出现打破了这种僵局。它允许模型在处理序列中的每一个元素时,都能“同时”观察到序列中的所有其他元素,并根据相关性分配注意力权重。


2. 核心直觉:类比“图书馆检索”(仅供参考)

为了理解 Self-Attention 的工作原理,可以引入三个核心概念:Query (Q), Key (K), 和 Value (V)

想象你在一个图书馆里找书:

  • Query (Q):你手里拿着的“查询搜索词”(我想找关于“老虎”的书籍)。
  • Key (K):书架上每本书的“标签/索引”(书 A 标签是“动物”,书 B 标签是“植物”)。
  • Value (V):书里的“具体内容”。

Self-Attention 的过程就是:

  1. 用你的 Query 去和所有的 Key 做匹配,计算相似度(注意力分数)。
  2. 经过归一化(Softmax),得到一组权重。
  3. 根据权重对所有的 Value 进行加权求和,得到最终的输出。

3. 数学原理:五步走

假设输入序列的嵌入矩阵为 XX,Self-Attention 的计算步骤如下:

第一步:线性变换

通过三个可学习的权重矩阵 WQ,WK,WVW^Q, W^K, W^V,将输入 XX 映射到三个空间:

Q=XWQQ = X \cdot W^Q

K=XWKK = X \cdot W^K

V=XWVV = X \cdot W^V

第二步:计算注意力分数(Attention Score)

计算 QQKK 的点积,衡量序列中词与词之间的相关性:

Score=QKT\text{Score} = Q \cdot K^T

第三步:缩放(Scaling)

为了防止点积结果过大导致 Softmax 梯度消失,除以 dk\sqrt{d_k}dkd_k 是向量维度):

Scaled Score=QKTdk\text{Scaled Score} = \frac{Q \cdot K^T}{\sqrt{d_k}}

第四步:归一化(Softmax)

通过 Softmax 将分数转化为概率分布(总和为 1):

α=softmax(QKTdk)\alpha = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right)

第五步:加权求和

用得到的权重 α\alphaVV 进行加权,提取出重要信息:

Attention(Q,K,V)=αV\text{Attention}(Q, K, V) = \alpha \cdot V


4. 总结公式

将上述步骤合并,就是大名鼎鼎的 Attention 公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V


5. Multi-Head Attention:多头注意力

为什么要用“多头”?
如果只有一组 Q,K,VQ, K, V,模型只能学到一种关注方式(比如只关注语法结构)。Multi-Head Attention 相当于让多个人从不同的角度看同一个句子:

  • 第一头:关注词与词的语义关系。
  • 第二头:关注指代关系(如 “it” 指代 “apple”)。
  • 第三头:关注时间/地点。

最后将所有头的输出拼接(Concat)并线性变换即可。


6. PyTorch 极简实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum does matrix multiplication for query*keys
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Softmax and Weighted Sum
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        out = torch.einsum("nhqk,nvhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

7. 结语

Self-Attention 的精妙之处在于它彻底改变了模型处理序列的方式。它不再通过“记忆”来传递信息,而是通过“观察”来全局捕捉特征。这也是后来 BERT, GPT, Claude 等大模型能够如此强大的底层基石。

如果你对 Transformer 架构感兴趣,理解 Self-Attention 就是通往现代 NLP 大门的第一把钥匙。


感谢阅读!如果你觉得有帮助,欢迎点赞分享。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。