Transformer架构的简要解析
Transformer架构自2017年诞生以来,已经彻底革新了人工智能领域,从最初的机器翻译任务扩展到几乎所有的序列建模问题。这种架构通过纯注意力机制取代了传统的循环和卷积结构,实现了前所未有的并行化能力和长距离依赖建模能力。其核心创新在于自注意力机制能够让序列中的任意两个位置直接交互,打破了RNN的序列处理瓶颈。从GPT到BERT,从ChatGPT到Claude,几乎所有现代大语言模型都建立在Transformer的基础之上,这种架构已经成为深度学习时代最重要的技术突破之一。
历史背景与革命性动机
在2017年之前,序列到序列的建模主要依赖于循环神经网络(RNN)和长短期记忆网络(LSTM),这些模型虽然在许多任务上取得了成功,但存在着根本性的架构限制。RNN必须按顺序处理序列,前一个时间步的计算必须完成后才能进行下一步,这种线性依赖关系使得模型无法充分利用现代GPU的并行计算能力。更严重的是,当序列长度增加时,梯度消失和梯度爆炸问题使得RNN难以捕捉长距离依赖关系。即使是专门设计来解决这个问题的LSTM,在处理超过200-500个词的序列时仍然力不从心。
Vaswani等人在"Attention is All You Need"这篇开创性论文中提出了一个大胆的想法:完全抛弃循环和卷积结构,仅使用注意力机制来构建序列模型。这个想法的核心洞察是,序列中任意两个位置之间的交互不应该受到它们距离的限制。在传统RNN中,相距n n n 个位置的两个词需要通过O ( n ) O(n) O ( n ) 次操作才能交互,而在Transformer中,这种交互只需要O ( 1 ) O(1) O ( 1 ) 次操作就能完成。这种架构在WMT 2014英德翻译任务上达到了28.4 BLEU分数,不仅超越了所有现有模型,训练速度还快了一个数量级。
这种架构革新的意义远超技术层面。它证明了归纳偏置(inductive bias)并非模型成功的必要条件,纯粹的注意力机制配合足够的数据和计算资源,可以学习到比精心设计的架构更好的表示。这个发现直接促成了后续大规模预训练模型的爆发,因为Transformer的可扩展性使得训练包含数千亿参数的模型成为可能。
自注意力机制的数学原理与深度剖析
自注意力机制是Transformer的核心创新,它通过计算序列中每个位置与所有其他位置的相关性来构建上下文表示。这个机制的数学优雅性在于它将复杂的序列建模问题转化为矩阵运算,从而充分利用现代硬件的并行计算能力。
给定一个输入序列X = [ x 1 , x 2 , . . . , x n ] X = [x_1, x_2, ..., x_n] X = [ x 1 , x 2 , . . . , x n ] ,其中每个x i ∈ R d m o d e l x_i \in \mathbb{R}^{d_{model}} x i ∈ R d m o d e l 是d d d 维的向量表示,自注意力机制首先通过三个不同的线性变换将每个输入向量映射到查询(Query)、键(Key)和值(Value)空间:
q i = x i W Q , k i = x i W K , v i = x i W V q_i = x_i W^Q, \quad k_i = x_i W^K, \quad v_i = x_i W^V
q i = x i W Q , k i = x i W K , v i = x i W V
这里W Q , W K ∈ R d m o d e l × d k W^Q, W^K \in \mathbb{R}^{d_{model} \times d_k} W Q , W K ∈ R d m o d e l × d k ,W V ∈ R d m o d e l × d v W^V \in \mathbb{R}^{d_{model} \times d_v} W V ∈ R d m o d e l × d v 是可学习的参数矩阵。这种三重映射的设计让模型能够学习输入的不同方面:查询向量代表"我在寻找什么信息",键向量代表"我能提供什么信息",值向量代表"实际传递的信息内容"。
注意力分数的计算采用缩放点积注意力(Scaled Dot-Product Attention):
e i j = q i T k j d k e_{ij} = \frac{q_i^T k_j}{\sqrt{d_k}}
e i j = d k q i T k j
这个公式中的缩放因子d k \sqrt{d_k} d k 至关重要。当d k d_k d k 较大时,点积的结果会变得很大,导致softmax函数进入梯度极小的饱和区域。通过除以d k \sqrt{d_k} d k ,我们将点积的方差控制在1左右,确保梯度能够有效传播。这个看似简单的技巧实际上是训练深层Transformer的关键。
接下来,通过softmax归一化将注意力分数转换为概率分布:
α i j = exp ( e i j ) ∑ l = 1 n exp ( e i l ) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{l=1}^n \exp(e_{il})}
α i j = ∑ l = 1 n exp ( e i l ) exp ( e i j )
最终,每个位置的输出是所有值向量的加权和:
h i = ∑ j = 1 n α i j v j h_i = \sum_{j=1}^n \alpha_{ij} v_j
h i = j = 1 ∑ n α i j v j
整个操作可以用矩阵形式简洁表达:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
Attention ( Q , K , V ) = softmax ( d k Q K T ) V
这种矩阵化的表达不仅在数学上优雅,更重要的是它能够充分利用现代GPU的矩阵运算能力。整个注意力计算过程中,所有位置的计算都是独立的,可以完全并行化,这是Transformer相比RNN的根本优势。
多头注意力的设计哲学与计算细节
多头注意力机制将单一的注意力扩展到多个表示子空间,这种设计源于一个关键观察:不同类型的语言关系需要不同的表示空间。例如,在处理"The cat sat on the mat"这个句子时,一个注意力头可能关注语法关系(主谓宾结构),另一个头可能关注语义关系(动物与位置的关系),还有的头可能关注局部的词序关系。
多头注意力的数学定义为:
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O
其中每个注意力头计算为:
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
head i = Attention ( Q W i Q , K W i K , V W i V )
这里的参数设置体现了精妙的设计平衡。原始Transformer使用h = 8 h=8 h = 8 个注意力头,每个头的维度d k = d v = d m o d e l / h = 64 d_k = d_v = d_{model}/h = 64 d k = d v = d m o d e l / h = 6 4 (当d m o d e l = 512 d_{model} = 512 d m o d e l = 5 1 2 时)。这种维度分割确保了多头注意力的计算复杂度与单头注意力保持一致,同时显著增强了模型的表达能力。每个头都有独立的投影矩阵W i Q , W i K , W i V ∈ R d m o d e l × d k W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k} W i Q , W i K , W i V ∈ R d m o d e l × d k ,使得不同的头能够学习到互补的注意力模式。
多头机制的实际效果令人惊叹。研究表明,不同的注意力头确实学到了不同类型的语言现象:有的头专注于捕捉位置信息,有的头学会了识别语法依存关系,还有的头能够追踪指代消解。这种自发的功能分化证明了多头设计的有效性,模型能够自动发现并编码多种语言规律,而无需显式的监督信号。
输出投影矩阵W O ∈ R h d v × d m o d e l W^O \in \mathbb{R}^{hd_v \times d_{model}} W O ∈ R h d v × d m o d e l 的作用不仅是维度转换,更重要的是它允许不同注意力头的信息进行交互和整合。这个线性变换学习如何最优地组合来自不同子空间的信息,形成统一的表示。
位置编码的数学基础与几何直觉
自注意力机制的一个根本特性是置换等变性(permutation equivariance),即如果我们改变输入序列的顺序,输出也会以相同的方式改变。这个特性虽然在某些场景下是优势,但对于自然语言处理来说却是致命的缺陷,因为词序对语义至关重要。"猫坐在垫子上"和"垫子坐在猫上"表达的是完全不同的意思。
Transformer采用正弦位置编码来注入位置信息,其数学形式为:
P E ( p o s , 2 i ) = sin ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)
P E ( p o s , 2 i ) = sin ( 1 0 0 0 0 2 i / d m o d e l p o s )
P E ( p o s , 2 i + 1 ) = cos ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
P E ( p o s , 2 i + 1 ) = cos ( 1 0 0 0 0 2 i / d m o d e l p o s )
这个看似复杂的公式背后有着深刻的数学直觉。每个维度对应一个不同频率的正弦波,频率从2 π 2\pi 2 π 到10000 ⋅ 2 π 10000 \cdot 2\pi 1 0 0 0 0 ⋅ 2 π 呈几何级数分布。低频分量编码全局位置信息,高频分量编码局部位置差异。这种多尺度的频率设计使得模型能够同时感知绝对位置和相对位置。
更重要的是,正弦函数具有一个关键的数学性质:对于任意固定的偏移量k k k ,P E ( p o s + k ) PE(pos+k) P E ( p o s + k ) 可以表示为P E ( p o s ) PE(pos) P E ( p o s ) 的线性函数。具体来说,存在一个旋转矩阵M M M 使得:
M [ sin ( ω p o s ) cos ( ω p o s ) ] = [ sin ( ω ( p o s + k ) ) cos ( ω ( p o s + k ) ) ] M\begin{bmatrix}\sin(\omega pos) \\ \cos(\omega pos)\end{bmatrix} = \begin{bmatrix}\sin(\omega(pos+k)) \\ \cos(\omega(pos+k))\end{bmatrix}
M [ sin ( ω p o s ) cos ( ω p o s ) ] = [ sin ( ω ( p o s + k ) ) cos ( ω ( p o s + k ) ) ]
这个性质意味着相对位置关系可以通过线性变换学习,这对于许多NLP任务至关重要。例如,在语法分析中,词与词之间的依存关系往往取决于它们的相对位置而非绝对位置。
位置编码通过简单的加法与词嵌入结合:
Input = TokenEmbedding ( x ) + PositionalEncoding ( p o s ) \text{Input} = \text{TokenEmbedding}(x) + \text{PositionalEncoding}(pos)
Input = TokenEmbedding ( x ) + PositionalEncoding ( p o s )
这种加法操作保持了模型的简洁性,同时确保位置信息能够影响后续的所有计算。相比于可学习的位置嵌入,正弦编码的优势在于它能够自然地泛化到训练时未见过的序列长度,这对于处理变长序列至关重要。
编码器的精密架构设计
Transformer编码器由N N N 个相同的层堆叠而成(原始论文中N = 6 N=6 N = 6 ),每层包含两个子层:多头自注意力机制和逐位置的前馈网络。这种设计体现了深度与宽度的平衡,既通过堆叠获得深层表示能力,又通过并行处理保持计算效率。
每个编码器层的数据流可以表示为:
首先是自注意力子层,它让每个位置都能关注到输入序列的所有位置,产生包含全局上下文信息的表示。注意力输出与输入通过残差连接相加,然后进行层归一化:
X ′ = LayerNorm ( X + MultiHeadAttention ( X , X , X ) ) X' = \text{LayerNorm}(X + \text{MultiHeadAttention}(X, X, X))
X ′ = LayerNorm ( X + MultiHeadAttention ( X , X , X ) )
接下来是前馈网络子层,对每个位置独立地应用两层全连接网络:
FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2
其中W 1 ∈ R d m o d e l × d f f W_1 \in \mathbb{R}^{d_{model} \times d_{ff}} W 1 ∈ R d m o d e l × d f f ,W 2 ∈ R d f f × d m o d e l W_2 \in \mathbb{R}^{d_{ff} \times d_{model}} W 2 ∈ R d f f × d m o d e l ,通常d f f = 4 ⋅ d m o d e l d_{ff} = 4 \cdot d_{model} d f f = 4 ⋅ d m o d e l 。这个4倍的扩展提供了必要的非线性变换能力,让模型能够学习复杂的特征交互。前馈网络的输出同样通过残差连接和层归一化:
Output = LayerNorm ( X ′ + FFN ( X ′ ) ) \text{Output} = \text{LayerNorm}(X' + \text{FFN}(X'))
Output = LayerNorm ( X ′ + FFN ( X ′ ) )
残差连接的重要性不能被低估。它们提供了从输入到输出的直接路径,使得梯度能够不受阻碍地反向传播。数学上,如果我们将第l l l 层的输出表示为x l x_l x l ,残差连接确保:
x l = x l − 1 + F ( x l − 1 ) x_l = x_{l-1} + F(x_{l-1})
x l = x l − 1 + F ( x l − 1 )
这意味着梯度∂ L ∂ x l − 1 \frac{\partial \mathcal{L}}{\partial x_{l-1}} ∂ x l − 1 ∂ L 包含一个恒等项,即使F F F 的梯度很小,信息仍然能够流动。这种设计使得训练100层以上的Transformer成为可能。
解码器的自回归机制与交叉注意力
解码器在编码器的基础上增加了额外的复杂性,以支持自回归生成。每个解码器层包含三个子层:掩码自注意力、编码器-解码器交叉注意力和前馈网络。这种三层结构精妙地平衡了目标序列的内部依赖关系和源序列的条件信息。
掩码自注意力确保了训练时的因果性(causality):在预测位置i i i 的输出时,模型只能看到位置1到i − 1 i-1 i − 1 的信息。这通过在注意力分数上应用一个上三角掩码矩阵实现:
Attention ( Q , K , V ) = softmax ( Q K T d k + M c a u s a l ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M_{causal}\right)V
Attention ( Q , K , V ) = softmax ( d k Q K T + M c a u s a l ) V
其中M c a u s a l M_{causal} M c a u s a l 是一个上三角矩阵,对角线以上的元素为− ∞ -\infty − ∞ 。经过softmax后,这些位置的注意力权重变为0,有效地阻止了信息从未来流向过去。
交叉注意力层是解码器的关键创新,它允许解码器"查看"编码器的输出。在这一层中,查询向量来自解码器的前一层,而键和值向量来自编码器的最终输出:
CrossAttention ( Q d e c , K e n c , V e n c ) = softmax ( Q d e c ⋅ K e n c T d k ) ⋅ V e n c \text{CrossAttention}(Q_{dec}, K_{enc}, V_{enc}) = \text{softmax}\left(\frac{Q_{dec} \cdot K_{enc}^T}{\sqrt{d_k}}\right) \cdot V_{enc}
CrossAttention ( Q d e c , K e n c , V e n c ) = softmax ( d k Q d e c ⋅ K e n c T ) ⋅ V e n c
这种设计让解码器能够动态地关注源序列的不同部分,实现了灵活的源-目标对齐。在机器翻译中,这允许模型在生成每个目标词时关注相关的源语言词汇。
前馈网络层的关键作用
前馈网络看似简单,实际上扮演着多个关键角色。首先,它是Transformer中唯一的非线性来源。自注意力机制本质上是线性的(除了softmax),如果没有FFN的ReLU激活,整个模型将退化为线性变换的叠加。
FFN的两层结构:
FFN ( x ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2
FFN ( x ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2
其中隐藏层维度d f f d_{ff} d f f 通常是模型维度d m o d e l d_{model} d m o d e l 的4倍。这种"先扩展后压缩"的设计提供了足够的容量来学习复杂的特征变换。研究表明,FFN实际上充当了键值存储器,存储了大量的事实知识和模式。
现代变体中,激活函数的选择变得更加多样。GELU(Gaussian Error Linear Unit)因其平滑的梯度特性被BERT和GPT采用:
GELU ( x ) = x ⋅ Φ ( x ) = x ⋅ 1 2 [ 1 + erf ( x 2 ) ] \text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]
GELU ( x ) = x ⋅ Φ ( x ) = x ⋅ 2 1 [ 1 + erf ( 2 x ) ]
SwiGLU(Swish-Gated Linear Unit)在LLaMA等模型中表现出色,它通过门控机制实现了更灵活的信息流动:
SwiGLU ( x , W , V ) = ( x ⋅ W ⊗ Swish ( x ⋅ V ) ) ⋅ W 2 \text{SwiGLU}(x, W, V) = (x \cdot W \otimes \text{Swish}(x \cdot V)) \cdot W_2
SwiGLU ( x , W , V ) = ( x ⋅ W ⊗ Swish ( x ⋅ V ) ) ⋅ W 2
FFN的位置独立性也很重要。它对每个位置独立地进行相同的变换,这种设计保持了计算的并行性,同时为每个token提供了独立的特征提取能力。
残差连接与层归一化的深层意义
残差连接和层归一化是训练深层Transformer的两个支柱技术。残差连接的数学形式简单:
Output = x + Sublayer ( x ) \text{Output} = x + \text{Sublayer}(x)
Output = x + Sublayer ( x )
但其影响深远。它不仅解决了梯度消失问题,还改变了学习的本质。有了残差连接,每一层不再需要学习完整的表示,而是学习对输入的残差修正。这种"学习差异而非整体"的思想大大简化了优化问题。
层归一化的作用同样关键:
LayerNorm ( x ) = γ ⋅ x − μ σ + β \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta
LayerNorm ( x ) = γ ⋅ σ x − μ + β
其中μ \mu μ 和σ \sigma σ 是在特征维度上计算的均值和标准差:
μ = 1 d ∑ i = 1 d x i , σ = 1 d ∑ i = 1 d ( x i − μ ) 2 \mu = \frac{1}{d}\sum_{i=1}^d x_i, \quad \sigma = \sqrt{\frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2}
μ = d 1 i = 1 ∑ d x i , σ = d 1 i = 1 ∑ d ( x i − μ ) 2
与批归一化不同,层归一化不依赖批次统计,这对于处理变长序列和小批量训练至关重要。它稳定了激活值的分布,防止了深层网络中常见的内部协变量偏移(internal covariate shift)。
现代实践中,Pre-LN(层归一化在子层之前)已经取代了原始的Post-LN设计:
x + \text{Sublayer}(\text{LayerNorm}(x))$$ vs $$\text{LayerNorm}(x + \text{Sublayer}(x))
Pre-LN提供了更稳定的训练动态,使得训练数百层的Transformer成为可能。这个看似微小的改变实际上是GPT-3等超大规模模型成功的关键因素之一。
Teacher Forcing训练技术的巧妙设计
Teacher Forcing是训练自回归模型的标准技术,它在训练时使用真实的目标序列作为解码器的输入,而不是模型自己的预测。这种方法极大地加速了训练过程,并提高了训练的稳定性。
在实际实现中,目标序列被移位处理:如果目标序列是[ y 1 , y 2 , . . . , y T ] [y_1, y_2, ..., y_T] [ y 1 , y 2 , . . . , y T ] ,那么解码器的输入是[ ⟨ BOS ⟩ , y 1 , y 2 , . . . , y T − 1 ] [\langle\text{BOS}\rangle, y_1, y_2, ..., y_{T-1}] [ ⟨ BOS ⟩ , y 1 , y 2 , . . . , y T − 1 ] ,期望输出是[ y 1 , y 2 , . . . , y T , ⟨ EOS ⟩ ] [y_1, y_2, ..., y_T, \langle\text{EOS}\rangle] [ y 1 , y 2 , . . . , y T , ⟨ EOS ⟩ ] 。配合因果掩码,这种设置允许并行计算所有位置的损失,而不是像RNN那样需要逐步展开。
训练时的损失函数通常是交叉熵损失,对词表上的概率分布进行优化:
L = − ∑ i = 1 T log P ( y i ∣ y 1 , . . . , y i − 1 , x ) \mathcal{L} = -\sum_{i=1}^T \log P(y_i | y_1, ..., y_{i-1}, x)
L = − i = 1 ∑ T log P ( y i ∣ y 1 , . . . , y i − 1 , x )
原始论文还使用了标签平滑(label smoothing),将真实标签的概率从1.0降低到0.9,剩余的0.1平均分配给其他词汇:
P s m o o t h ( y ) = ( 1 − ϵ ) δ y , y ∗ + ϵ ∣ V ∣ P_{smooth}(y) = (1-\epsilon)\delta_{y,y^*} + \frac{\epsilon}{|V|}
P s m o o t h ( y ) = ( 1 − ϵ ) δ y , y ∗ + ∣ V ∣ ϵ
这种技术防止模型过度自信,提高了泛化能力。
然而,Teacher Forcing也带来了训练与推理之间的不匹配问题(exposure bias)。训练时模型总是看到正确的历史,而推理时必须基于自己的预测。现代方法如scheduled sampling试图通过在训练过程中逐渐引入模型预测来缓解这个问题。
注意力掩码的多样化应用
注意力掩码是Transformer灵活性的关键来源,通过不同的掩码模式,同一个架构可以适应各种任务需求。填充掩码(padding mask)处理变长序列,通过将填充位置的注意力权重设为0来忽略无意义的填充token。因果掩码(causal mask)确保自回归特性,防止模型看到未来信息。在BERT等双向模型中,使用了更复杂的掩码策略。训练时随机掩盖15%的token,模型需要基于上下文预测被掩盖的内容。这种掩码语言模型(MLM)训练创造了强大的双向表示。
现代研究还探索了稀疏注意力模式,如局部窗口注意力、条纹注意力和随机注意力。这些模式通过限制注意力的范围来降低计算复杂度,同时保持模型性能。例如,Longformer使用滑动窗口注意力处理局部上下文,配合全局注意力处理关键信息。
计算复杂度的深入分析
标准自注意力的时间复杂度是O ( n 2 ⋅ d ) O(n^2 \cdot d) O ( n 2 ⋅ d ) ,其中n n n 是序列长度,d d d 是隐藏维度。这个二次复杂度主要来自计算所有位置对之间的注意力分数。空间复杂度同样是O ( n 2 + n ⋅ d ) O(n^2 + n \cdot d) O ( n 2 + n ⋅ d ) ,需要存储完整的注意力矩阵和中间激活。
让我们详细分析各个组件的复杂度:
Query/Key/Value投影:O ( n ⋅ d 2 ) O(n \cdot d^2) O ( n ⋅ d 2 )
注意力分数计算:O ( n 2 ⋅ d ) O(n^2 \cdot d) O ( n 2 ⋅ d )
Softmax归一化:O ( n 2 ) O(n^2) O ( n 2 )
加权求和:O ( n 2 ⋅ d ) O(n^2 \cdot d) O ( n 2 ⋅ d )
输出投影:O ( n ⋅ d 2 ) O(n \cdot d^2) O ( n ⋅ d 2 )
当序列长度增加时,这种二次增长很快成为瓶颈。处理10,000个token的序列需要1亿次注意力计算,这解释了为什么早期Transformer模型的上下文长度通常限制在512-2048个token。相比之下,RNN的时间复杂度是O ( n ⋅ d 2 ) O(n \cdot d^2) O ( n ⋅ d 2 ) ,但需要O ( n ) O(n) O ( n ) 次序列操作,无法并行化。
内存带宽也是一个关键考虑因素。现代GPU的计算能力远超内存带宽,这意味着数据传输往往是瓶颈。FlashAttention等优化技术通过重新组织计算来减少内存访问,实现了2-3倍的实际加速,即使理论复杂度没有改变。
Transformer相对于传统架构的革命性优势
Transformer相对于RNN的最大优势是完全并行化。在RNN中,时间步t t t 的计算必须等待t − 1 t-1 t − 1 完成,这种序列依赖使得训练极其缓慢。Transformer可以同时处理整个序列,将训练时间从周缩短到天甚至小时。在长距离依赖建模方面,Transformer的优势更加明显。RNN中相距n n n 个位置的信息需要经过n n n 次状态转换,每次转换都可能造成信息损失。Transformer通过自注意力机制提供了直接连接,任意两个位置之间只有常数距离。这种"全连接"特性使得Transformer能够轻松处理数千个token的序列。
相比CNN,Transformer的优势在于灵活性。CNN的感受野是固定的,需要很深的网络才能获得全局视野。而Transformer从第一层就具有全局感受野,能够根据内容动态地分配注意力。这种内容相关的处理方式特别适合自然语言,因为语言中的依赖关系往往是非局部的和不规则的。
实践中的关键优化技巧
混合精度训练已经成为训练大规模Transformer的标准做法。使用FP16进行前向和反向传播,同时保持FP32的主权重和优化器状态,可以将内存使用减少50%,在配备Tensor Core的GPU上实现2倍加速。损失缩放(loss scaling)防止梯度下溢:
Loss s c a l e d = Loss × S \text{Loss}_{scaled} = \text{Loss} \times S
Loss s c a l e d = Loss × S
其中S S S 是缩放因子,通常从2 15 2^{15} 2 1 5 开始,动态调整。
梯度检查点(gradient checkpointing)通过重计算来节省内存。不保存所有的中间激活,而是只保存关键检查点,在反向传播时重新计算需要的激活。这种方法可以将内存使用从O ( n ) O(n) O ( n ) 减少到O ( n ) O(\sqrt{n}) O ( n ) ,代价是约20%的速度损失。
学习率调度对Transformer训练至关重要。原始论文使用的warmup策略已经成为标准:
l r = d m o d e l − 0.5 ⋅ min ( step − 0.5 , step ⋅ warmup_steps − 1.5 ) lr = d_{model}^{-0.5} \cdot \min(\text{step}^{-0.5}, \text{step} \cdot \text{warmup\_steps}^{-1.5})
l r = d m o d e l − 0 . 5 ⋅ min ( step − 0 . 5 , step ⋅ warmup_steps − 1 . 5 )
这种策略在训练初期线性增加学习率,避免了大梯度造成的不稳定,然后按照反平方根规律衰减,确保模型能够收敛到好的局部最优。
初始化策略同样重要。Xavier初始化确保前向传播和反向传播时的方差保持一致:
W ∼ N ( 0 , 2 n i n + n o u t ) W \sim \mathcal{N}\left(0, \frac{2}{n_{in} + n_{out}}\right)
W ∼ N ( 0 , n i n + n o u t 2 )
对于深层网络,合适的初始化能够决定训练的成败。
附录:Transformer核心算法的推导
A. 自注意力机制的完整梯度推导
考虑自注意力的前向传播过程,我们需要推导损失函数L \mathcal{L} L 关于各个参数的梯度。
设输入序列X ∈ R n × d X \in \mathbb{R}^{n \times d} X ∈ R n × d ,其中n n n 是序列长度,d d d 是特征维度。
首先计算查询、键和值:
Q = X W Q , K = X W K , V = X W V Q = XW^Q, \quad K = XW^K, \quad V = XW^V
Q = X W Q , K = X W K , V = X W V
注意力分数矩阵:
S = Q K T d k = X W Q ( X W K ) T d k S = \frac{QK^T}{\sqrt{d_k}} = \frac{XW^Q(XW^K)^T}{\sqrt{d_k}}
S = d k Q K T = d k X W Q ( X W K ) T
应用softmax:
A = softmax ( S ) = exp ( S ) ∑ j exp ( S : , j ) A = \text{softmax}(S) = \frac{\exp(S)}{\sum_j \exp(S_{:,j})}
A = softmax ( S ) = ∑ j exp ( S : , j ) exp ( S )
输出:
Y = A V = softmax ( S ) ⋅ X W V Y = AV = \text{softmax}(S) \cdot XW^V
Y = A V = softmax ( S ) ⋅ X W V
现在推导反向传播。设∂ L ∂ Y = d Y \frac{\partial \mathcal{L}}{\partial Y} = dY ∂ Y ∂ L = d Y 已知。
首先,关于V V V 的梯度:
∂ L ∂ V = A T ⋅ d Y \frac{\partial \mathcal{L}}{\partial V} = A^T \cdot dY
∂ V ∂ L = A T ⋅ d Y
关于A A A 的梯度:
∂ L ∂ A = d Y ⋅ V T \frac{\partial \mathcal{L}}{\partial A} = dY \cdot V^T
∂ A ∂ L = d Y ⋅ V T
关于S S S 的梯度需要通过softmax的Jacobian:
∂ A i j ∂ S i k = A i j ( δ j k − A i k ) \frac{\partial A_{ij}}{\partial S_{ik}} = A_{ij}(\delta_{jk} - A_{ik})
∂ S i k ∂ A i j = A i j ( δ j k − A i k )
因此:
∂ L ∂ S i j = A i j ∑ k ∂ L ∂ A i k ( 1 − A i j ) − A i j 2 ∑ k ≠ j ∂ L ∂ A i k \frac{\partial \mathcal{L}}{\partial S_{ij}} = A_{ij} \sum_k \frac{\partial \mathcal{L}}{\partial A_{ik}}(1 - A_{ij}) - A_{ij}^2 \sum_{k \neq j} \frac{\partial \mathcal{L}}{\partial A_{ik}}
∂ S i j ∂ L = A i j k ∑ ∂ A i k ∂ L ( 1 − A i j ) − A i j 2 k = j ∑ ∂ A i k ∂ L
简化后:
∂ L ∂ S = A ⊙ ( d A − rowsum ( d A ⊙ A ) ) \frac{\partial \mathcal{L}}{\partial S} = A \odot (dA - \text{rowsum}(dA \odot A))
∂ S ∂ L = A ⊙ ( d A − rowsum ( d A ⊙ A ) )
其中⊙ \odot ⊙ 表示逐元素乘法,rowsum \text{rowsum} rowsum 表示按行求和并广播。
关于Q Q Q 和K K K 的梯度:
∂ L ∂ Q = 1 d k ∂ L ∂ S ⋅ K \frac{\partial \mathcal{L}}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial S} \cdot K
∂ Q ∂ L = d k 1 ∂ S ∂ L ⋅ K
∂ L ∂ K = 1 d k ∂ L ∂ S T ⋅ Q \frac{\partial \mathcal{L}}{\partial K} = \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial S}^T \cdot Q
∂ K ∂ L = d k 1 ∂ S ∂ L T ⋅ Q
最终,关于权重矩阵的梯度:
∂ L ∂ W Q = X T ⋅ ∂ L ∂ Q \frac{\partial \mathcal{L}}{\partial W^Q} = X^T \cdot \frac{\partial \mathcal{L}}{\partial Q}
∂ W Q ∂ L = X T ⋅ ∂ Q ∂ L
∂ L ∂ W K = X T ⋅ ∂ L ∂ K \frac{\partial \mathcal{L}}{\partial W^K} = X^T \cdot \frac{\partial \mathcal{L}}{\partial K}
∂ W K ∂ L = X T ⋅ ∂ K ∂ L
∂ L ∂ W V = X T ⋅ ∂ L ∂ V \frac{\partial \mathcal{L}}{\partial W^V} = X^T \cdot \frac{\partial \mathcal{L}}{\partial V}
∂ W V ∂ L = X T ⋅ ∂ V ∂ L
B. 多头注意力的参数效率分析
多头注意力将d m o d e l d_{model} d m o d e l 维度分割成h h h 个头,每个头的维度为d h = d m o d e l / h d_h = d_{model}/h d h = d m o d e l / h 。
总参数量:
单头注意力:3 × d m o d e l 2 + d m o d e l 2 = 4 d m o d e l 2 3 \times d_{model}^2 + d_{model}^2 = 4d_{model}^2 3 × d m o d e l 2 + d m o d e l 2 = 4 d m o d e l 2
多头注意力:h × 3 × d m o d e l × d h + d m o d e l 2 = 3 d m o d e l 2 + d m o d e l 2 = 4 d m o d e l 2 h \times 3 \times d_{model} \times d_h + d_{model}^2 = 3d_{model}^2 + d_{model}^2 = 4d_{model}^2 h × 3 × d m o d e l × d h + d m o d e l 2 = 3 d m o d e l 2 + d m o d e l 2 = 4 d m o d e l 2
计算复杂度分析:
投影操作:O ( 3 n ⋅ d m o d e l 2 ) O(3n \cdot d_{model}^2) O ( 3 n ⋅ d m o d e l 2 )
注意力计算(每个头):O ( n 2 ⋅ d h ) O(n^2 \cdot d_h) O ( n 2 ⋅ d h )
总注意力计算:O ( h ⋅ n 2 ⋅ d h ) = O ( n 2 ⋅ d m o d e l ) O(h \cdot n^2 \cdot d_h) = O(n^2 \cdot d_{model}) O ( h ⋅ n 2 ⋅ d h ) = O ( n 2 ⋅ d m o d e l )
输出投影:O ( n ⋅ d m o d e l 2 ) O(n \cdot d_{model}^2) O ( n ⋅ d m o d e l 2 )
总复杂度:O ( n 2 ⋅ d m o d e l + n ⋅ d m o d e l 2 ) O(n^2 \cdot d_{model} + n \cdot d_{model}^2) O ( n 2 ⋅ d m o d e l + n ⋅ d m o d e l 2 )
这表明多头注意力在不增加计算复杂度的情况下,提供了h h h 倍的表示子空间。
C. 位置编码的线性变换性质证明
证明:对于正弦位置编码,存在线性变换将位置p o s pos p o s 的编码映射到位置p o s + k pos+k p o s + k 的编码。
设位置编码为:
P E p o s , 2 i = sin ( ω i ⋅ p o s ) PE_{pos,2i} = \sin(\omega_i \cdot pos)
P E p o s , 2 i = sin ( ω i ⋅ p o s )
P E p o s , 2 i + 1 = cos ( ω i ⋅ p o s ) PE_{pos,2i+1} = \cos(\omega_i \cdot pos)
P E p o s , 2 i + 1 = cos ( ω i ⋅ p o s )
其中ω i = 1 / 1000 0 2 i / d m o d e l \omega_i = 1/10000^{2i/d_{model}} ω i = 1 / 1 0 0 0 0 2 i / d m o d e l 。
对于固定的偏移k k k ,我们有:
P E p o s + k , 2 i = sin ( ω i ⋅ ( p o s + k ) ) = sin ( ω i ⋅ p o s + ω i ⋅ k ) PE_{pos+k,2i} = \sin(\omega_i \cdot (pos + k)) = \sin(\omega_i \cdot pos + \omega_i \cdot k)
P E p o s + k , 2 i = sin ( ω i ⋅ ( p o s + k ) ) = sin ( ω i ⋅ p o s + ω i ⋅ k )
使用三角恒等式:
sin ( α + β ) = sin ( α ) cos ( β ) + cos ( α ) sin ( β ) \sin(\alpha + \beta) = \sin(\alpha)\cos(\beta) + \cos(\alpha)\sin(\beta)
sin ( α + β ) = sin ( α ) cos ( β ) + cos ( α ) sin ( β )
cos ( α + β ) = cos ( α ) cos ( β ) − sin ( α ) sin ( β ) \cos(\alpha + \beta) = \cos(\alpha)\cos(\beta) - \sin(\alpha)\sin(\beta)
cos ( α + β ) = cos ( α ) cos ( β ) − sin ( α ) sin ( β )
因此:
[ sin ( ω i ( p o s + k ) ) cos ( ω i ( p o s + k ) ) ] = [ cos ( ω i k ) sin ( ω i k ) − sin ( ω i k ) cos ( ω i k ) ] [ sin ( ω i ⋅ p o s ) cos ( ω i ⋅ p o s ) ] \begin{bmatrix}
\sin(\omega_i(pos+k)) \\
\cos(\omega_i(pos+k))
\end{bmatrix} = \begin{bmatrix}
\cos(\omega_i k) & \sin(\omega_i k) \\
-\sin(\omega_i k) & \cos(\omega_i k)
\end{bmatrix} \begin{bmatrix}
\sin(\omega_i \cdot pos) \\
\cos(\omega_i \cdot pos)
\end{bmatrix} [ sin ( ω i ( p o s + k ) ) cos ( ω i ( p o s + k ) ) ] = [ cos ( ω i k ) − sin ( ω i k ) sin ( ω i k ) cos ( ω i k ) ] [ sin ( ω i ⋅ p o s ) cos ( ω i ⋅ p o s ) ]
这是一个旋转矩阵变换,证明了相对位置可以通过线性变换学习。
D. 残差连接的梯度流分析
考虑L L L 层的残差网络,第l l l 层的输出为:
x l = x l − 1 + F l ( x l − 1 ) x_l = x_{l-1} + F_l(x_{l-1})
x l = x l − 1 + F l ( x l − 1 )
递归展开:
x L = x 0 + ∑ l = 1 L F l ( x l − 1 ) x_L = x_0 + \sum_{l=1}^L F_l(x_{l-1})
x L = x 0 + l = 1 ∑ L F l ( x l − 1 )
损失关于x l x_l x l 的梯度:
∂ L ∂ x l = ∂ L ∂ x L ∂ x L ∂ x l \frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \frac{\partial x_L}{\partial x_l}
∂ x l ∂ L = ∂ x L ∂ L ∂ x l ∂ x L
由于残差结构:
∂ x L ∂ x l = ∂ x L ∂ x L − 1 ⋯ ∂ x l + 1 ∂ x l \frac{\partial x_L}{\partial x_l} = \frac{\partial x_L}{\partial x_{L-1}} \cdots \frac{\partial x_{l+1}}{\partial x_l}
∂ x l ∂ x L = ∂ x L − 1 ∂ x L ⋯ ∂ x l ∂ x l + 1
= ∏ i = l + 1 L ( I + ∂ F i ( x i − 1 ) ∂ x i − 1 ) = \prod_{i=l+1}^L \left(I + \frac{\partial F_i(x_{i-1})}{\partial x_{i-1}}\right)
= i = l + 1 ∏ L ( I + ∂ x i − 1 ∂ F i ( x i − 1 ) )
展开后包含恒等项I I I ,确保梯度不会消失。即使所有∂ F i ∂ x i − 1 ≈ 0 \frac{\partial F_i}{\partial x_{i-1}} \approx 0 ∂ x i − 1 ∂ F i ≈ 0 ,梯度仍然至少为:
∂ L ∂ x l ≈ ∂ L ∂ x L \frac{\partial \mathcal{L}}{\partial x_l} \approx \frac{\partial \mathcal{L}}{\partial x_L}
∂ x l ∂ L ≈ ∂ x L ∂ L
E. 层归一化的梯度计算
给定输入x ∈ R d x \in \mathbb{R}^d x ∈ R d ,层归一化定义为:
y = γ ⊙ x − μ σ + β y = \gamma \odot \frac{x - \mu}{\sigma} + \beta
y = γ ⊙ σ x − μ + β
其中:
μ = 1 d ∑ i = 1 d x i , σ = 1 d ∑ i = 1 d ( x i − μ ) 2 + ϵ \mu = \frac{1}{d}\sum_{i=1}^d x_i, \quad \sigma = \sqrt{\frac{1}{d}\sum_{i=1}^d (x_i - \mu)^2 + \epsilon}
μ = d 1 i = 1 ∑ d x i , σ = d 1 i = 1 ∑ d ( x i − μ ) 2 + ϵ
设标准化后的值为x ^ = x − μ σ \hat{x} = \frac{x - \mu}{\sigma} x ^ = σ x − μ 。
反向传播时,给定∂ L ∂ y = d y \frac{\partial \mathcal{L}}{\partial y} = dy ∂ y ∂ L = d y :
∂ L ∂ γ = d y ⊙ x ^ \frac{\partial \mathcal{L}}{\partial \gamma} = dy \odot \hat{x}
∂ γ ∂ L = d y ⊙ x ^
∂ L ∂ β = d y \frac{\partial \mathcal{L}}{\partial \beta} = dy
∂ β ∂ L = d y
关于x ^ \hat{x} x ^ 的梯度:
∂ L ∂ x ^ = d y ⊙ γ \frac{\partial \mathcal{L}}{\partial \hat{x}} = dy \odot \gamma
∂ x ^ ∂ L = d y ⊙ γ
关于x x x 的梯度需要考虑μ \mu μ 和σ \sigma σ 的依赖:
∂ L ∂ x i = ∂ L ∂ x ^ i ∂ x ^ i ∂ x i + ∑ j = 1 d ∂ L ∂ x ^ j ∂ x ^ j ∂ μ ∂ μ ∂ x i + ∑ j = 1 d ∂ L ∂ x ^ j ∂ x ^ j ∂ σ ∂ σ ∂ x i \frac{\partial \mathcal{L}}{\partial x_i} = \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i} + \sum_{j=1}^d \frac{\partial \mathcal{L}}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial \mu} \frac{\partial \mu}{\partial x_i} + \sum_{j=1}^d \frac{\partial \mathcal{L}}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial \sigma} \frac{\partial \sigma}{\partial x_i}
∂ x i ∂ L = ∂ x ^ i ∂ L ∂ x i ∂ x ^ i + j = 1 ∑ d ∂ x ^ j ∂ L ∂ μ ∂ x ^ j ∂ x i ∂ μ + j = 1 ∑ d ∂ x ^ j ∂ L ∂ σ ∂ x ^ j ∂ x i ∂ σ
经过详细推导:
∂ L ∂ x = 1 σ ( d x ^ − ∑ i d x ^ i − x ^ ∑ i d x ^ i x ^ i ) ⊙ γ / d \frac{\partial \mathcal{L}}{\partial x} = \frac{1}{\sigma} \left(d\hat{x} - \sum_i d\hat{x}_i - \hat{x} \sum_i d\hat{x}_i \hat{x}_i\right) \odot \gamma / d
∂ x ∂ L = σ 1 ( d x ^ − i ∑ d x ^ i − x ^ i ∑ d x ^ i x ^ i ) ⊙ γ / d
F. Softmax温度缩放的信息论解释
Softmax函数的温度版本:
p i = exp ( z i / T ) ∑ j exp ( z j / T ) p_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}
p i = ∑ j exp ( z j / T ) exp ( z i / T )
当T → 0 T \to 0 T → 0 时,分布趋向于one-hot(argmax):
lim T → 0 p i = { 1 if i = arg max j z j 0 otherwise \lim_{T \to 0} p_i = \begin{cases}
1 & \text{if } i = \arg\max_j z_j \\
0 & \text{otherwise}
\end{cases} T → 0 lim p i = { 1 0 if i = arg max j z j otherwise
当T → ∞ T \to \infty T → ∞ 时,分布趋向均匀:
lim T → ∞ p i = 1 n \lim_{T \to \infty} p_i = \frac{1}{n}
T → ∞ lim p i = n 1
熵的变化:
H ( p ) = − ∑ i p i log p i H(p) = -\sum_i p_i \log p_i
H ( p ) = − i ∑ p i log p i
温度控制了分布的熵:
在注意力机制中,T = d k T = \sqrt{d_k} T = d k 的选择基于:
Var [ q T k ] = d k ⋅ Var [ q i ] ⋅ Var [ k i ] = d k \text{Var}[q^T k] = d_k \cdot \text{Var}[q_i] \cdot \text{Var}[k_i] = d_k
Var [ q T k ] = d k ⋅ Var [ q i ] ⋅ Var [ k i ] = d k
通过除以d k \sqrt{d_k} d k ,使方差归一化为1,保持合适的熵水平。
G. Teacher Forcing的偏差-方差分解
考虑序列生成的损失函数:
L = E y ∼ p d a t a [ − log p m o d e l ( y ∣ x ) ] \mathcal{L} = \mathbb{E}_{y \sim p_{data}}[-\log p_{model}(y|x)]
L = E y ∼ p d a t a [ − log p m o d e l ( y ∣ x ) ]
在Teacher Forcing下:
L T F = ∑ t = 1 T E y t ∼ p d a t a [ − log p m o d e l ( y t ∣ y 1 : t − 1 ∗ , x ) ] \mathcal{L}_{TF} = \sum_{t=1}^T \mathbb{E}_{y_t \sim p_{data}}[-\log p_{model}(y_t|y_{1:t-1}^*, x)]
L T F = t = 1 ∑ T E y t ∼ p d a t a [ − log p m o d e l ( y t ∣ y 1 : t − 1 ∗ , x ) ]
在自回归推理下:
L A R = ∑ t = 1 T E y t ∼ p m o d e l [ − log p m o d e l ( y t ∣ y ^ 1 : t − 1 , x ) ] \mathcal{L}_{AR} = \sum_{t=1}^T \mathbb{E}_{y_t \sim p_{model}}[-\log p_{model}(y_t|\hat{y}_{1:t-1}, x)]
L A R = t = 1 ∑ T E y t ∼ p m o d e l [ − log p m o d e l ( y t ∣ y ^ 1 : t − 1 , x ) ]
两者的差异(exposure bias):
Δ = L A R − L T F \Delta = \mathcal{L}_{AR} - \mathcal{L}_{TF}
Δ = L A R − L T F
可以分解为:
Δ = ∑ t = 1 T D K L ( p d a t a ( y 1 : t − 1 ) ∣ ∣ p m o d e l ( y 1 : t − 1 ) ) ⏟ 分布偏移 + E [ Var [ p m o d e l ( y t ∣ ⋅ ) ] ] ⏟ 误差累积 \Delta = \underbrace{\sum_{t=1}^T D_{KL}(p_{data}(y_{1:t-1}) || p_{model}(y_{1:t-1}))}_{\text{分布偏移}} + \underbrace{\mathbb{E}[\text{Var}[p_{model}(y_t|\cdot)]]}_{\text{误差累积}}
Δ = 分布偏移 t = 1 ∑ T D K L ( p d a t a ( y 1 : t − 1 ) ∣ ∣ p m o d e l ( y 1 : t − 1 ) ) + 误差累积 E [ Var [ p m o d e l ( y t ∣ ⋅ ) ] ]
这解释了为什么长序列生成质量会下降。
H. 计算效率的硬件感知优化
现代GPU的内存层次:
寄存器:~20KB,延迟1周期
L1缓存:~128KB,延迟~100周期
L2缓存:~40MB,延迟~200周期
HBM:~80GB,延迟~300周期
标准注意力的内存访问模式:
从HBM读取Q , K , V Q, K, V Q , K , V :3 n d 3nd 3 n d 字节
计算S = Q K T S = QK^T S = Q K T ,写入HBM:n 2 n^2 n 2 字节
计算softmax,读写S S S :2 n 2 2n^2 2 n 2 字节
计算Y = softmax ( S ) V Y = \text{softmax}(S)V Y = softmax ( S ) V :读n 2 + n d n^2 + nd n 2 + n d ,写n d nd n d 字节
总内存访问:O ( n 2 + n d ) O(n^2 + nd) O ( n 2 + n d ) 字节
FlashAttention通过分块计算优化:
将Q , K , V Q, K, V Q , K , V 分成块大小B r × B c B_r \times B_c B r × B c
在SRAM中完成块内计算
使用在线softmax避免存储完整注意力矩阵
内存访问减少到:O ( n 2 d / M ) O(n^2d/M) O ( n 2 d / M ) ,其中M M M 是SRAM大小。
当M = Θ ( d ) M = \Theta(d) M = Θ ( d ) 时,内存访问变为O ( n 2 ) O(n^2) O ( n 2 ) ,实现了与计算复杂度的匹配。
I. 注意力机制的核方法解释
自注意力可以视为隐式的核方法。定义核函数:
κ ( x i , x j ) = exp ( ⟨ ϕ ( x i ) , ϕ ( x j ) ⟩ d ) \kappa(x_i, x_j) = \exp\left(\frac{\langle \phi(x_i), \phi(x_j) \rangle}{\sqrt{d}}\right)
κ ( x i , x j ) = exp ( d ⟨ ϕ ( x i ) , ϕ ( x j ) ⟩ )
其中ϕ ( x ) = x W Q \phi(x) = xW^Q ϕ ( x ) = x W Q 或x W K xW^K x W K 是特征映射。
标准注意力:
Attention ( X ) = ∑ j κ ( x i , x j ) v j ∑ j κ ( x i , x j ) \text{Attention}(X) = \frac{\sum_j \kappa(x_i, x_j) v_j}{\sum_j \kappa(x_i, x_j)}
Attention ( X ) = ∑ j κ ( x i , x j ) ∑ j κ ( x i , x j ) v j
这是核密度估计的形式。线性注意力通过显式特征映射近似:
ϕ ( x ) = 1 m [ sin ( W x ) , cos ( W x ) ] \phi(x) = \frac{1}{\sqrt{m}}[\sin(Wx), \cos(Wx)]
ϕ ( x ) = m 1 [ sin ( W x ) , cos ( W x ) ]
使得:
κ ( x i , x j ) ≈ ⟨ ϕ ( x i ) , ϕ ( x j ) ⟩ \kappa(x_i, x_j) \approx \langle \phi(x_i), \phi(x_j) \rangle
κ ( x i , x j ) ≈ ⟨ ϕ ( x i ) , ϕ ( x j ) ⟩
计算变为:
LinearAttention ( X ) = ϕ ( Q ) ( ϕ ( K ) T V ) ϕ ( Q ) 1 \text{LinearAttention}(X) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)\mathbf{1}}
LinearAttention ( X ) = ϕ ( Q ) 1 ϕ ( Q ) ( ϕ ( K ) T V )
复杂度从O ( n 2 d ) O(n^2d) O ( n 2 d ) 降至O ( n m d ) O(nmd) O ( n m d ) ,其中m m m 是特征维度。
J. 优化器动量与Transformer训练的相互作用
Adam优化器在Transformer训练中的关键作用:
一阶矩估计(动量):
m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t
m t = β 1 m t − 1 + ( 1 − β 1 ) g t
二阶矩估计(自适应学习率):
v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2
v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2
偏差修正:
m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}
m ^ t = 1 − β 1 t m t , v ^ t = 1 − β 2 t v t
参数更新:
θ t + 1 = θ t − η m ^ t v ^ t + ϵ \theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
θ t + 1 = θ t − η v ^ t + ϵ m ^ t
对于Transformer的深层结构,梯度方差很大。Adam的自适应学习率特别重要:
SNR = ∣ E [ g ] ∣ 2 Var [ g ] \text{SNR} = \frac{|\mathbb{E}[g]|^2}{\text{Var}[g]}
SNR = Var [ g ] ∣ E [ g ] ∣ 2
当SNR低时(深层网络常见),Adam通过v t \sqrt{v_t} v t 归一化提高稳定性。
Warmup与Adam的交互:
η t = min ( t T w a r m u p , 1 ) ⋅ η b a s e \eta_t = \min\left(\frac{t}{T_{warmup}}, 1\right) \cdot \eta_{base}
η t = min ( T w a r m u p t , 1 ) ⋅ η b a s e
在warmup阶段,二阶矩估计逐渐稳定,防止训练初期的大幅震荡。
评论(0)