注意力机制的“瘦身”革命:从多头自注意力到线性复杂度变体的演进
在深度学习的发展史上,Transformer 的出现无疑是一座分水岭。它抛弃了循环神经网络(RNN)逐步处理的范式,利用自注意力机制实现了真正的并行化计算。然而,当我们试图将 Transformer 从 NLP 领域迁移到长序列任务(如高分辨率图像生成、长文档建模、基因组分析)时,一个被称为“二次方魔咒”的问题横亘在面前——自注意力的计算复杂度是序列长度的平方 。
这就意味着,当序列长度翻倍时,显存占用和计算量会变成原来的四倍。面对这个瓶颈,算法工程师们开始了一场针对注意力机制的“瘦身”运动。在这篇文章中,我将结合实战经验,深入剖析如何通过稀疏注意力以及 Linformer/Performer 等 Transformer 变体,打破计算复杂度的枷锁。
一、 基石:多头自注意力与位置编码
在进入优化之前,我们需要先回顾一下标准的 Transformer 架构。它的核心在于多头自注意力。
标准的 MSRA 将输入向量映射为 Query()、Key()、Value()三个矩阵。注意力的本质是计算 Query 和 Key 的相似度,并以此为权重聚合 Value。
这里有一个极易被忽视但至关重要的组件:位置编码。由于自注意力机制本质上是集合运算,它对输入顺序是不敏感的。如果不注入位置信息,“我打你”和“你打我”在模型眼里是一模一样的。
在早期的 BERT 和 GPT 中,我们使用正弦/余弦函数来生成绝对位置编码。但在长序列任务中,绝对位置编码有一个明显的缺陷:外推性差。当测试序列长度超过训练序列最大长度时,模型的表现会断崖式下跌。
为了解决这个问题,我们转而使用相对位置编码(如 T5 使用的 Bias)或 RoPE(旋转位置编码)。RoPE 通过复数域的旋转操作将位置信息注入到 和 的点积中,不仅计算优雅,而且在长距离外推上表现出了惊人的鲁棒性。这在我们的长文本生成项目中起到了定海神针的作用。
二、 瓶颈:二次方复杂度的代价
为什么 这么可怕?
假设我们处理一个长度为 4096 的序列(这在 NLP 中很常见,在视频处理中算短的)。Attention Score 矩阵的大小就是 。如果我们将序列增加到 16K(处理高分辨率图像 Patch 所需),这个矩阵会变成 个元素。即便使用半精度(FP16),仅仅存储这个注意力矩阵就需要 512MB 显存!更别提还要对其进行 Softmax 和矩阵乘法运算。
这就是我们迫切需要优化的原因。
三、 第一阶段:稀疏注意力
既然全连接的图太稠密,那能不能把它变稀疏?稀疏注意力的直观思路是:每个 Token 不需要关注所有 Token,只需要关注最相关的那些。
最经典的实现是 Longformer 和 BigBird 提出的滑动窗口模式:
- 局部窗口:每个 Token 关注左右 个邻居。
- 全局 Token:几个特定的 Token(如 [CLS])关注所有 Token,并被所有 Token 关注。
这种模式将计算复杂度降低到了 ,甚至线性 。
在代码实现中,我们可以通过自定义 Mask 来实现这种稀疏化:
import torch
import torch.nn.functional as F
def sliding_window_attention(q, k, v, window_size, mask=None):
"""
实现滑动窗口注意力(简化版,未考虑 Padding)
q, k, v: [batch_size, num_heads, seq_len, head_dim]
"""
batch_size, num_heads, seq_len, head_dim = q.shape
# 1. 计算标准注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5) # [bs, heads, seq_len, seq_len]
# 2. 构造滑动窗口 Mask
# 创建一个 [seq_len, seq_len] 的下三角矩阵,距离大于 window_size 的设为 -inf
indices = torch.arange(seq_len)
window_mask = torch.abs(indices.unsqueeze(0) - indices.unsqueeze(1)) > window_size
scores = scores.masked_fill(window_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
# 3. 计算 Softmax 和输出
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
return output
这种方法的缺点是损失了部分全局信息。虽然可以通过 dilation(膨胀窗口)来扩大感受野,但在需要极长距离依赖的任务中,它依然显得力不从心。
四、 第二阶段:Transformer 变体——Linformer
稀疏注意力是通过“扔掉”一些连接来降低复杂度,而 Linformer 则是从数学层面进行降维打击。
Linformer 的核心假设是:在低维空间中,Self-Attention 矩阵具有低秩特性。
回顾注意力公式,瓶颈在于 这个 的矩阵。Linformer 引入了两个投影矩阵 和 ,将 Key () 和 Value ( 的维度从 降维到 (其中 )。
等等,这中间的逆矩阵不好算。Linformer 实际上更直接地修改了结构:它不再计算 的分数矩阵,而是先对 和 进行线性投影降维,再计算 的分数矩阵。
这使得复杂度直接降为 。当 是一个常数(比如 256)时,这就是线性复杂度。
表:稀疏注意力与 Linformer 的对比
| 特性 | 稀疏注意力 | Linformer |
|---|---|---|
| 核心思想 | 结构化剪枝(只关注局部/全局) | 低秩近似(线性投影) |
| 计算复杂度 | (w为窗口大小) | (k为降维后维度) |
| 显存占用 | 低(只需存储局部连接) | 低( 矩阵变小) |
| 全局信息 | 较弱(依赖全局 Token) | 较强(保留了所有 Token 的压缩信息) |
| 适用场景 | 长文档、时间序列 | 高分辨率图像、海量数据预训练 |
| 在实际使用 Linformer 时,我们发现它在预训练阶段收敛速度略慢于标准 Transformer,但在推理阶段速度提升巨大,特别是在序列长度超过 8K 以后。 |
五、 第三阶段:Transformer 变体——Performer
如果说 Linformer 是线性的,那么 Performer 就是线性的“随机化”大师。它利用了随机傅里叶特征和正交随机特征来近似 Softmax 操作。
这是极客味道最浓的一种优化。Softmax 运算中的指数项 可以被理解为核函数。Performer 证明了我们可以通过特征映射 和 ,将点积的核运算转化为线性空间的点积:
一旦做了这个变换,注意力的计算顺序就变了:
注意括号的位置!我们先计算 ,这一步与序列长度 无关(取决于 Key 维度)。然后再与 相乘。整个过程的复杂度完美地变成了 。
Performer 的代码实现非常复杂,涉及到正交随机特征的重参数化技巧,但在 PyTorch 中可以借助 einops 库进行简化描述:
# 概念伪代码:展示 Performer 的核心变换逻辑
import torch
import einops
def performer_attention(q, k, v, projection_dim=256):
"""
q, k, v: [batch, seq_len, dim]
"""
# 1. 随机特征映射 (简化版,实际使用正交化随机矩阵)
# 这里的 projection 随机生成且固定
projection = torch.randn(k.shape[-1], projection_dim, device=k.device)
# F(x) 映射: 这里省略了 Performer 中复杂的 exp 及归一化细节
q_prime = torch.exp(torch.matmul(q, projection))
k_prime = torch.exp(torch.matmul(k, projection))
# 2. 变换矩阵乘法顺序
# 标准: Q @ K.T @ V -> (N, N) @ V -> O(N^2)
# Performer: (K.T @ V) 先算 -> (dim, N) @ (N, dim) -> O(N * dim^2)
k_t_v = torch.matmul(k_prime.transpose(-2, -1), v)
# 3. 归一化因子 D 的计算 (省略细节)
# ...
# 4. 最终聚合
output = torch.matmul(q_prime, k_t_v)
return output
Performer 的最大优势在于它是无偏估计,且不需要像 Linformer 那样训练额外的投影矩阵(虽然也可以训练)。在我们的音频处理任务中,Performer 展现出了极快的推理速度,且几乎不损失精度。
六、 总结与选型建议
从标准的多头自注意力到 Linformer 和 Performer,我们走过的每一步都是在精度与效率之间做权衡。
作为算法工程师,在选型时我有以下几点建议:
- 短序列(< 1024):老老实实用标准 Transformer + RoPE,不要折腾,性能最好。
- 中等长序列(1K - 8K):首选 稀疏注意力(如 Longformer)。结构清晰,可解释性强,很多主流推理引擎(如 ONNX Runtime)对稀疏算子有专门优化。
- 超长序列(> 8K):必须上 Linformer 或 Performer。
- 如果你希望复现现有的 Transformer 预训练权重,或者希望模型结构尽可能简单,选 Linformer(因为它在结构上改动较小,主要是降维)。
- 如果你追求极致的理论性能下限和线性复杂度的数学美感,选 Performer。
注意力的优化之路依然在继续,比如 FlashAttention 通过硬件感知的 IO 精确计算进一步提升了显存利用率。但在架构层面,理解这几种变体的数学原理,将帮助你设计出应对万物互联时代的下一代 AI 模型。
- 点赞
- 收藏
- 关注作者
评论(0)