注意力机制的“瘦身”革命:从多头自注意力到线性复杂度变体的演进

举报
i-WIFI 发表于 2026/01/24 14:18:10 2026/01/24
【摘要】 在深度学习的发展史上,Transformer 的出现无疑是一座分水岭。它抛弃了循环神经网络(RNN)逐步处理的范式,利用自注意力机制实现了真正的并行化计算。然而,当我们试图将 Transformer 从 NLP 领域迁移到长序列任务(如高分辨率图像生成、长文档建模、基因组分析)时,一个被称为“二次方魔咒”的问题横亘在面前——自注意力的计算复杂度是序列长度的平方 O(N2)O(N^2)O(N...

在深度学习的发展史上,Transformer 的出现无疑是一座分水岭。它抛弃了循环神经网络(RNN)逐步处理的范式,利用自注意力机制实现了真正的并行化计算。然而,当我们试图将 Transformer 从 NLP 领域迁移到长序列任务(如高分辨率图像生成、长文档建模、基因组分析)时,一个被称为“二次方魔咒”的问题横亘在面前——自注意力的计算复杂度是序列长度的平方 O(N2)O(N^2)
这就意味着,当序列长度翻倍时,显存占用和计算量会变成原来的四倍。面对这个瓶颈,算法工程师们开始了一场针对注意力机制的“瘦身”运动。在这篇文章中,我将结合实战经验,深入剖析如何通过稀疏注意力以及 Linformer/Performer 等 Transformer 变体,打破计算复杂度的枷锁。

一、 基石:多头自注意力与位置编码

在进入优化之前,我们需要先回顾一下标准的 Transformer 架构。它的核心在于多头自注意力
标准的 MSRA 将输入向量映射为 Query(QQ)、Key(KK)、Value(VV)三个矩阵。注意力的本质是计算 Query 和 Key 的相似度,并以此为权重聚合 Value。

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V

这里有一个极易被忽视但至关重要的组件:位置编码。由于自注意力机制本质上是集合运算,它对输入顺序是不敏感的。如果不注入位置信息,“我打你”和“你打我”在模型眼里是一模一样的。
在早期的 BERT 和 GPT 中,我们使用正弦/余弦函数来生成绝对位置编码。但在长序列任务中,绝对位置编码有一个明显的缺陷:外推性差。当测试序列长度超过训练序列最大长度时,模型的表现会断崖式下跌。
为了解决这个问题,我们转而使用相对位置编码(如 T5 使用的 Bias)或 RoPE(旋转位置编码)。RoPE 通过复数域的旋转操作将位置信息注入到 QQKK 的点积中,不仅计算优雅,而且在长距离外推上表现出了惊人的鲁棒性。这在我们的长文本生成项目中起到了定海神针的作用。

二、 瓶颈:二次方复杂度的代价

为什么 O(N2)O(N^2) 这么可怕?
假设我们处理一个长度为 4096 的序列(这在 NLP 中很常见,在视频处理中算短的)。Attention Score 矩阵的大小就是 4096×40964096 \times 4096。如果我们将序列增加到 16K(处理高分辨率图像 Patch 所需),这个矩阵会变成 256M256 \text{M} 个元素。即便使用半精度(FP16),仅仅存储这个注意力矩阵就需要 512MB 显存!更别提还要对其进行 Softmax 和矩阵乘法运算。
这就是我们迫切需要优化的原因。

三、 第一阶段:稀疏注意力

既然全连接的图太稠密,那能不能把它变稀疏?稀疏注意力的直观思路是:每个 Token 不需要关注所有 Token,只需要关注最相关的那些。
最经典的实现是 Longformer 和 BigBird 提出的滑动窗口模式

  1. 局部窗口:每个 Token 关注左右 ww 个邻居。
  2. 全局 Token:几个特定的 Token(如 [CLS])关注所有 Token,并被所有 Token 关注。
    这种模式将计算复杂度降低到了 O(Nw)O(N \cdot w),甚至线性 O(N)O(N)
    在代码实现中,我们可以通过自定义 Mask 来实现这种稀疏化:
import torch
import torch.nn.functional as F
def sliding_window_attention(q, k, v, window_size, mask=None):
    """
    实现滑动窗口注意力(简化版,未考虑 Padding)
    q, k, v: [batch_size, num_heads, seq_len, head_dim]
    """
    batch_size, num_heads, seq_len, head_dim = q.shape
    
    # 1. 计算标准注意力分数
    scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5) # [bs, heads, seq_len, seq_len]
    
    # 2. 构造滑动窗口 Mask
    # 创建一个 [seq_len, seq_len] 的下三角矩阵,距离大于 window_size 的设为 -inf
    indices = torch.arange(seq_len)
    window_mask = torch.abs(indices.unsqueeze(0) - indices.unsqueeze(1)) > window_size
    scores = scores.masked_fill(window_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
    
    # 3. 计算 Softmax 和输出
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    
    return output

这种方法的缺点是损失了部分全局信息。虽然可以通过 dilation(膨胀窗口)来扩大感受野,但在需要极长距离依赖的任务中,它依然显得力不从心。

四、 第二阶段:Transformer 变体——Linformer

稀疏注意力是通过“扔掉”一些连接来降低复杂度,而 Linformer 则是从数学层面进行降维打击。
Linformer 的核心假设是:在低维空间中,Self-Attention 矩阵具有低秩特性
回顾注意力公式,瓶颈在于 QKTQK^T 这个 N×NN \times N 的矩阵。Linformer 引入了两个投影矩阵 EEFF,将 Key (KK) 和 Value (VV 的维度从 NN 降维到 kk(其中 kNk \ll N)。

Attention(Q,K,V)Q(EKTdk)1(FV)T\text{Attention}(Q, K, V) \approx Q \left( \frac{E K^T}{\sqrt{d_k}} \right)^{-1} (F V)^T

等等,这中间的逆矩阵不好算。Linformer 实际上更直接地修改了结构:它不再计算 N×NN \times N 的分数矩阵,而是先对 KKVV 进行线性投影降维,再计算 N×kN \times k 的分数矩阵。
这使得复杂度直接降为 O(Nk)O(Nk)。当 kk 是一个常数(比如 256)时,这就是线性复杂度。
表:稀疏注意力与 Linformer 的对比

特性 稀疏注意力 Linformer
核心思想 结构化剪枝(只关注局部/全局) 低秩近似(线性投影)
计算复杂度 O(Nw)O(N \cdot w) (w为窗口大小) O(Nk)O(N \cdot k) (k为降维后维度)
显存占用 低(只需存储局部连接) 低(K,VK, V 矩阵变小)
全局信息 较弱(依赖全局 Token) 较强(保留了所有 Token 的压缩信息)
适用场景 长文档、时间序列 高分辨率图像、海量数据预训练
在实际使用 Linformer 时,我们发现它在预训练阶段收敛速度略慢于标准 Transformer,但在推理阶段速度提升巨大,特别是在序列长度超过 8K 以后。

五、 第三阶段:Transformer 变体——Performer

如果说 Linformer 是线性的,那么 Performer 就是线性的“随机化”大师。它利用了随机傅里叶特征正交随机特征来近似 Softmax 操作。
这是极客味道最浓的一种优化。Softmax 运算中的指数项 exp(qk)\exp(q \cdot k) 可以被理解为核函数。Performer 证明了我们可以通过特征映射 ϕ(q)\phi(q)ϕ(k)\phi(k),将点积的核运算转化为线性空间的点积:

exp(qk)ϕ(q)Tϕ(k)\exp(q \cdot k) \approx \phi(q)^T \phi(k)

一旦做了这个变换,注意力的计算顺序就变了:

AttentionD1(ϕ(Q)(ϕ(K)TV))\text{Attention} \approx D^{-1} (\phi(Q) (\phi(K)^T V))

注意括号的位置!我们先计算 ϕ(K)TV\phi(K)^T V,这一步与序列长度 NN 无关(取决于 Key 维度)。然后再与 ϕ(Q)\phi(Q) 相乘。整个过程的复杂度完美地变成了 O(N)O(N)
Performer 的代码实现非常复杂,涉及到正交随机特征的重参数化技巧,但在 PyTorch 中可以借助 einops 库进行简化描述:

# 概念伪代码:展示 Performer 的核心变换逻辑
import torch
import einops
def performer_attention(q, k, v, projection_dim=256):
    """
    q, k, v: [batch, seq_len, dim]
    """
    # 1. 随机特征映射 (简化版,实际使用正交化随机矩阵)
    # 这里的 projection 随机生成且固定
    projection = torch.randn(k.shape[-1], projection_dim, device=k.device)
    
    # F(x) 映射: 这里省略了 Performer 中复杂的 exp 及归一化细节
    q_prime = torch.exp(torch.matmul(q, projection))
    k_prime = torch.exp(torch.matmul(k, projection))
    
    # 2. 变换矩阵乘法顺序
    # 标准: Q @ K.T @ V -> (N, N) @ V -> O(N^2)
    # Performer: (K.T @ V) 先算 -> (dim, N) @ (N, dim) -> O(N * dim^2)
    k_t_v = torch.matmul(k_prime.transpose(-2, -1), v)
    
    # 3. 归一化因子 D 的计算 (省略细节)
    # ...
    
    # 4. 最终聚合
    output = torch.matmul(q_prime, k_t_v)
    return output

Performer 的最大优势在于它是无偏估计,且不需要像 Linformer 那样训练额外的投影矩阵(虽然也可以训练)。在我们的音频处理任务中,Performer 展现出了极快的推理速度,且几乎不损失精度。

六、 总结与选型建议

从标准的多头自注意力到 Linformer 和 Performer,我们走过的每一步都是在精度与效率之间做权衡。
作为算法工程师,在选型时我有以下几点建议:

  1. 短序列(< 1024):老老实实用标准 Transformer + RoPE,不要折腾,性能最好。
  2. 中等长序列(1K - 8K):首选 稀疏注意力(如 Longformer)。结构清晰,可解释性强,很多主流推理引擎(如 ONNX Runtime)对稀疏算子有专门优化。
  3. 超长序列(> 8K):必须上 LinformerPerformer
    • 如果你希望复现现有的 Transformer 预训练权重,或者希望模型结构尽可能简单,选 Linformer(因为它在结构上改动较小,主要是降维)。
    • 如果你追求极致的理论性能下限和线性复杂度的数学美感,选 Performer
      注意力的优化之路依然在继续,比如 FlashAttention 通过硬件感知的 IO 精确计算进一步提升了显存利用率。但在架构层面,理解这几种变体的数学原理,将帮助你设计出应对万物互联时代的下一代 AI 模型。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。