大语言模型的核心算法——简要解析
Transformer架构的数学本质与演进
自注意力机制的核心原理
Transformer架构的灵魂在于自注意力机制,它允许模型在处理序列中的每个元素时,动态地关注序列中的所有其他位置。从数学角度看,自注意力的计算过程可以表达为:
Attention(Q,K,V)=softmax(dkQKT)V
当输入序列X∈Rn×dmodel进入自注意力层时,首先通过三个不同的线性变换生成查询(Query)、键(Key)和值(Value)矩阵:
Q=XWQ,K=XWK,V=XWV
其中WQ,WK,WV∈Rdmodel×dk是可学习的参数矩阵。查询矩阵Q代表当前位置想要获取的信息,键矩阵K代表每个位置能够提供的信息特征,而值矩阵V则是实际要传递的信息内容。通过计算Q和K的点积,模型得到了一个注意力分数矩阵,表示每个位置对其他所有位置的关注程度。
这里的缩放因子dk起着关键作用。随着维度dk的增加,点积的方差会线性增长,可能导致softmax函数进入饱和区域,梯度变得极小。具体来说,假设q和k的分量是独立同分布的随机变量,均值为0,方差为1,则它们点积的方差为dk。通过除以dk进行缩放,确保了注意力分数保持在合理的范围内,维持了训练的稳定性。
多头注意力(Multi-Head Attention)进一步扩展了这一机制。通过并行运行多个注意力头,每个头关注不同的表示子空间,模型能够同时捕获多种类型的依赖关系:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个注意力头计算为:
headi=Attention(QWiQ,KWiK,VWiV)
这里WiQ∈Rdmodel×dk,WiK∈Rdmodel×dk,WiV∈Rdmodel×dv,而WO∈Rhdv×dmodel将所有头的输出整合。这种设计让模型能够从多个角度理解输入序列的结构和语义。
位置编码的革命性演进
原始Transformer使用正弦位置编码来注入序列的顺序信息:
PE(pos,2i)=sin(100002i/dmodelpos)
PE(pos,2i+1)=cos(100002i/dmodelpos)
这种编码方式虽然简单有效,但在处理超长序列时存在局限性。
2024年,旋转位置嵌入(RoPE)已经成为事实标准,被LLaMA、GPT-NeoX等主流模型广泛采用。RoPE的核心创新在于通过旋转矩阵编码绝对位置信息,同时自然地融入相对位置依赖。对于位置m的token embedding xm,RoPE应用旋转变换:
f(xm,m)=Rmxm
其中旋转矩阵Rm定义为:
Rm=⎝⎜⎜⎜⎜⎜⎜⎛cosmθ0sinmθ000⋮−sinmθ0cosmθ000⋮00cosmθ1sinmθ1⋮00−sinmθ1cosmθ1⋮⋯⋯⋯⋯⋱⎠⎟⎟⎟⎟⎟⎟⎞
频率参数θi=10000−2i/d。RoPE的优雅之处在于,当计算两个位置m和n的注意力分数时:
⟨f(q,m),f(k,n)⟩=qTRmTRnk=qTRn−mk
点积结果自然包含了相对位置信息(n−m)。这使得模型能够处理训练时未见过的序列长度,实现了真正的长度外推能力。
相比之下,ALiBi(Attention with Linear Biases)采用了更加直接的方法,通过在注意力分数中添加线性偏置来编码位置信息:
AttentionALiBi(Q,K,V)=softmax(dkQKT−m⋅∣i−j∣)V
其中m是头特定的斜率参数,∣i−j∣是位置差的绝对值。
Flash Attention与硬件协同优化
2024年7月发布的Flash Attention v3代表了算法与硬件协同设计的巅峰。标准注意力计算的时间复杂度为O(N2d),空间复杂度为O(N2),其中N是序列长度,d是特征维度。Flash Attention通过分块计算和内存优化,将空间复杂度降低到O(N)。
其核心思想是将注意力矩阵分块计算,避免存储完整的N×N注意力矩阵。对于输入序列,Flash Attention将其分为大小为Br×Bc的块,并使用以下更新规则:
mijnew=max(mijold,m~ij)
lijnew=emijold−mijnewlijold+em~ij−mijnewl~ij
Oijnew=lijnewlijoldemijold−mijnewOijold+lijnewl~ijem~ij−mijnewO~ij
其中m是最大值,l是指数和,O是输出。这种增量更新方式允许逐块处理,显著减少了内存访问。
在FP16精度下,Flash Attention v3达到了740 TFLOPs/s的性能,相当于H100理论峰值的75%利用率。它还支持FP8低精度计算,在保持数值稳定性的同时达到了1.2 PFLOPs/s的性能。
分组查询注意力(GQA)作为另一项重要创新,通过将查询头分组共享键值对,在效率与质量之间找到了完美平衡。假设有H个查询头和G个键值组(G<H),则第i个查询头使用第⌊i⋅G/H⌋组的键值对:
GQAi=Attention(Qi,K⌊i⋅G/H⌋,V⌊i⋅G/H⌋)
以Llama 2 70B为例,使用8个组可以将KV缓存减少87.5%,而模型性能几乎不受影响。
主流模型的架构创新与技术突破
GPT-4的稀疏专家混合架构
根据已知信息,GPT-4采用了Mixture of Experts(MoE)架构,总参数约1.8万亿,包含16个专家网络,每个约1110亿参数。关键的是,每次前向传播仅激活2个专家,实际使用约2200亿参数。
MoE架构的核心在于路由机制。对于输入token x,路由网络计算每个专家的选择概率:
G(x)=TopK(softmax(Wg⋅x+ϵ),k)
其中Wg是路由权重,ϵ是噪声项用于探索,TopK选择概率最高的k个专家。最终输出为:
y=i=1∑kGi(x)⋅Ei(x)
这里Ei(x)是第i个专家的输出,Gi(x)是对应的门控权重。
负载均衡是MoE训练的关键挑战。DeepSeek-V3通过auxiliary-loss-free策略优化了这一过程,定义负载均衡损失为:
Lbalance=α⋅i=1∑N(fi−N1)2
其中fi是第i个专家的平均激活频率,N是专家总数,α是平衡系数。
Claude的Constitutional AI训练范式
Anthropic的Constitutional AI代表了一种全新的模型对齐方法。不同于传统的RLHF需要大量人工标注,Constitutional AI通过两个阶段实现自监督学习。
在第一阶段,模型生成响应后,基于预定义的宪法原则进行自我批评和修订。这个过程可以形式化为条件生成:$$p!\left(y_{\text{revised}} \mid x,, y_{\text{initial}},, \mathcal{C}\right)
= \prod_{t=1}^{T} p!\left(y_t \mid y_{<t},, x,, y_{\text{initial}},, \mathcal{C}\right)$$
其中 C 表示宪法原则集合,yinitial 是初始响应,yrevised 是修订后的响应。
在第二阶段,使用AI反馈(RLAIF)训练偏好模型。偏好模型的目标是学习条件分布:
P(y1≻y2∣x,C)=σ(rθ(x,y1,C)−rθ(x,y2,C))
其中rθ是参数化的奖励函数,σ是sigmoid函数,y1≻y2表示y1优于y2。
LLaMA 3.1的密集架构优化
Meta的LLaMA 3.1 405B虽然采用密集架构而非MoE,但通过一系列技术创新实现了卓越性能。模型使用RMSNorm替代LayerNorm:
RMSNorm(x)=RMS(x)x⋅γ
其中:
RMS(x)=d1i=1∑dxi2
这种简化不仅提高了训练稳定性,还减少了约7%的计算开销。
SwiGLU激活函数是另一项关键创新:
SwiGLU(x)=Swish(xW1)⊗(xW2)
其中Swish函数定义为:
Swish(x)=x⋅σ(βx)=1+e−βxx
当β=1时,SwiGLU提供了比ReLU更平滑的梯度,其导数为:
dxdSwish=σ(βx)+βx⋅σ(βx)(1−σ(βx))
这种平滑性有助于模型学习更复杂的表示,特别是在深层网络中。
训练技术的范式转变
从RLHF到DPO的演进
强化学习从人类反馈(RLHF)曾是大模型对齐的黄金标准。RLHF的目标是最大化期望奖励,同时限制与参考策略的偏离:
πθmaxEx∼D,y∼πθ(y∣x)[Rϕ(x,y)]−βDKL[πθ(y∣x)∣∣πref(y∣x)]
通过拉格朗日乘数法,可以得到最优策略的闭式解:
π∗(y∣x)=Z(x)1πref(y∣x)exp(βR(x,y))
其中Z(x)是配分函数:
Z(x)=y∑πref(y∣x)exp(βR(x,y))
直接偏好优化(DPO)的关键洞察是,可以从最优策略的形式反推出奖励函数:
R(x,y)=βlogπref(y∣x)π∗(y∣x)+βlogZ(x)
基于Bradley-Terry模型,人类偏好的概率可以表示为:
P(yw≻yl∣x)=σ(R(x,yw)−R(x,yl))
代入奖励函数的表达式并化简(注意Z(x)项会抵消),得到DPO的损失函数:
LDPO(πθ;πref)=−E(x,yw,yl)∼D[logσ(βlogπref(yw∣x)πθ(yw∣x)−βlogπref(yl∣x)πθ(yl∣x))]
这种方法消除了奖励模型训练阶段,直接优化策略参数,计算效率提升11%,内存使用减少11%。
长上下文处理的技术突破
Ring Attention通过分布式计算突破了序列长度限制。假设有P个设备,序列长度为N,每个设备处理N/P长度的子序列。注意力计算分解为:
Attentionglobal=p=1⋃PLocalAttentionp+CrossAttentionp→(p+1)modP
在环形拓扑中,设备p在第t轮接收来自设备(p−t)modP的键值对,计算交叉注意力:
Ap,t=softmax(dkQpK(p−t)modPT)V(p−t)modP
最终通过P轮通信完成全局注意力计算,通信复杂度为O(N)而非O(N2)。
StreamingLLM基于"注意力汇聚"现象,保留初始ksink个tokens作为锚点,配合大小为w的滑动窗口:
KVCache={KV1:ksink}∪{KVt−w+1:t}
注意力计算时,初始tokens获得的注意力权重满足:
i=1∑ksinkαi≈0.2∼0.3
即使这些tokens在语义上并不重要。这种设计使得模型能够稳定处理400万tokens的序列,速度比基线方法快22.2倍。
附录:核心算法
A. 自注意力机制的梯度推导与优化
考虑自注意力的前向传播:
A=softmax(dkQKT),Y=AV
我们需要推导损失函数L关于Q、K、V的梯度。首先定义中间变量:
S=dkQKT∈Rn×n
对于softmax函数,其Jacobian矩阵为:
∂Sik∂Aij=Aij(δjk−Aik)
其中δjk是Kronecker delta函数。
通过链式法则,损失关于S的梯度为:
∂S∂L=∂Y∂L⋅∂A∂Y⋅∂S∂A
展开计算:
∂Sij∂L=k∑∂Yik∂LVjkAij−k,l∑∂Yik∂LVlkAilAij
简化为矩阵形式:
∂S∂L=A⊙(∂Y∂LVT−diag(A⋅(∂Y∂LVT)))
进而得到关于Q和K的梯度:
∂Q∂L=dk1∂S∂LK
∂K∂L=dk1∂S∂LTQ
值矩阵的梯度更直接:
∂V∂L=AT∂Y∂L
B. 旋转位置编码(RoPE)的数学性质
定理1:RoPE保持内积的相对位置依赖性。
证明:对于两个位置m和n的向量q和k,应用RoPE后的内积为:
⟨fq(q,m),fk(k,n)⟩=qTRΘT(m)RΘ(n)k
由于旋转矩阵的性质RΘT(m)=RΘ(−m),我们有:
RΘT(m)RΘ(n)=RΘ(n−m)
将向量q和k分解为2维子空间:
q=[q0,q1,...,qd/2−1]T
对于每个2维子空间i,内积计算为:
⟨fq(qi,m),fk(ki,n)⟩=qiT(cos((n−m)θi)sin((n−m)θi)−sin((n−m)θi)cos((n−m)θi))ki
展开得:
=qi,0ki,0cos((n−m)θi)+qi,0ki,1sin((n−m)θi)−qi,1ki,0sin((n−m)θi)+qi,1ki,1cos((n−m)θi)
=(qi,0ki,0+qi,1ki,1)cos((n−m)θi)+(qi,0ki,1−qi,1ki,0)sin((n−m)θi)
这表明内积仅依赖于相对位置(n−m),证毕。
C. DPO损失函数的变分推导
从RLHF的目标函数出发:
J(πθ)=Ex∼D,y∼πθ[R(x,y)]−βDKL[πθ∣∣πref]
KL散度可以展开为:
DKL[πθ∣∣πref]=Ey∼πθ[logπref(y∣x)πθ(y∣x)]
将目标函数重写为:
J(πθ)=Ey∼πθ[R(x,y)−βlogπref(y∣x)πθ(y∣x)]
对πθ求变分导数并令其为零:
δπθ(y∣x)δJ=R(x,y)−βlogπref(y∣x)πθ(y∣x)−β−λ(x)=0
其中λ(x)是保证∑yπθ(y∣x)=1的拉格朗日乘数。
解得最优策略:
π∗(y∣x)=πref(y∣x)exp(βR(x,y)−λ(x))
归一化条件给出:
y∑π∗(y∣x)=1⇒eλ(x)/β=y∑πref(y∣x)exp(βR(x,y))=Z(x)
因此:
π∗(y∣x)=Z(x)1πref(y∣x)exp(βR(x,y))
反解奖励函数:
R(x,y)=βlogπref(y∣x)π∗(y∣x)+βlogZ(x)
对于偏好对(yw,yl),Bradley-Terry模型给出:
P(yw≻yl)=exp(R(x,yw))+exp(R(x,yl))exp(R(x,yw))=σ(R(x,yw)−R(x,yl))
代入奖励函数表达式(Z(x)项抵消):
P(yw≻yl)=σ(βlogπref(yw∣x)π∗(yw∣x)−βlogπref(yl∣x)π∗(yl∣x))
最大似然估计给出DPO损失函数:
LDPO=−E[logσ(βlogπref(yw∣x)πθ(yw∣x)−βlogπref(yl∣x)πθ(yl∣x))]
D. Mixture of Experts的负载均衡优化
考虑MoE层的前向传播,输入x∈Rd,有N个专家{Ei}i=1N,门控网络G。
标准MoE的输出为:
y=i=1∑NGi(x)Ei(x)
其中门控权重通过softmax计算:
G(x)=softmax(Wgx+bg)
为了实现稀疏激活,使用Top-K操作:
G~(x)=KeepTopK(G(x),k)
负载均衡的目标是使每个专家的期望负载相等。定义专家i的负载为:
Li=Ex∼D[1[i∈TopK(G(x),k)]]
理想情况下,Li=k/N 对所有i成立。
引入辅助损失促进均衡:
Laux=α⋅N⋅i=1∑Nfi⋅Pi
其中fi是批次中路由到专家i的token比例:
fi=∣B∣1x∈B∑1[i∈TopK(G(x),k)]
Pi是专家i的平均路由概率:
Pi=∣B∣1x∈B∑Gi(x)
这个损失函数鼓励fi和Pi都接近1/N,从而实现负载均衡。
E. Flash Attention的数值稳定性分析
Flash Attention使用在线softmax算法避免数值溢出。对于输入序列分块Qi,Kj,Vj,计算过程如下:
定义局部最大值和指数和:
mij=max(QiKjT/dk)
ℓij=∑exp(QiKjT/dk−mij)
增量更新规则保证数值稳定性:
minew=max(miold,mij)
ℓinew=emiold−minewℓiold+emij−minewℓij
Oinew=ℓinewemiold−minewℓioldOiold+ℓinewemij−minewℓijO~ij
定理2:Flash Attention的输出与标准注意力在数值上等价。
证明:考虑完整的注意力计算:
A=softmax(S)=∑jexp(Sj)exp(S)
使用log-sum-exp技巧:
logj∑exp(Sj)=m+logj∑exp(Sj−m)
其中m=maxjSj。Flash Attention通过分块计算维护:
- 全局最大值:mglobal=maxi,jmij
- 全局指数和:ℓglobal=∑i,jexp(mij−mglobal)ℓij
- 加权输出:Oglobal=ℓglobal1∑i,jexp(mij−mglobal)ℓijOij
通过归纳法可证明这与标准计算等价,且避免了指数溢出。
F. 梯度累积与混合精度训练的理论分析
在大模型训练中,由于内存限制,常使用梯度累积技术。设批量大小为B,累积步数为K,则有效批量大小为Beff=K⋅B。
标准SGD更新:
θt+1=θt−η∇θL(θt;Beff)
梯度累积更新:
gk=K1∇θL(θt;Bk),k=1,...,K
θt+1=θt−ηk=1∑Kgk
定理3:在凸优化条件下,梯度累积与标准SGD收敛速度相同。
证明:假设损失函数L是L-光滑的,即:
∣∣∇L(θ1)−∇L(θ2)∣∣≤L∣∣θ1−θ2∣∣
对于梯度累积,期望梯度为:
E[k=1∑Kgk]=E[∇θL(θt;Beff)]
方差分析:
Var[k=1∑Kgk]=K1Var[∇θL(θt;B)]
这与使用大批量Beff的方差相同,因此收敛速度一致。
混合精度训练使用FP16计算和FP32主权重。损失缩放防止梯度下溢:
Lscaled=s⋅L
gFP16=∇θLscaled=s⋅∇θL
更新规则:
θFP32=θFP32−η⋅sgFP16
动态损失缩放通过监控梯度溢出自适应调整s:
st+1={st⋅ρupst/ρdownif no overflow for N stepsif overflow detected
典型参数:ρup=2,ρdown=2,N=2000。
G. 量化技术的信息论分析
权重量化可以从率失真理论角度分析。设原始权重W∈Rm×n,量化后权重W^,量化函数Q。
均方量化误差:
D=E[∣∣W−W^∣∣F2]=i,j∑(Wij−W^ij)2
对于k-bit均匀量化:
W^ij=s⋅round(sWij)
其中量化步长:
s=2k−1max(W)−min(W)
量化信噪比(SQNR):
SQNR=10log10(E[(W−W^)2]E[W2])≈6.02k+4.77−20log10(μWσW)
AWQ(Activation-aware Weight Quantization)通过重要性加权优化量化:
W^minx∈D∑∣∣f(x;W)−f(x;W^)∣∣2
定义权重重要性:
Iij=Ex∼D[∣∣∣∣∣∂Wij∂f(x;W)∣∣∣∣∣]
AWQ保持top-p%重要权重为全精度:
W^ij={WijQk(Wij)if Iij>τpotherwise
实验表明,保持0.1%的关键权重可以将量化损失减少75%。
H. 长序列注意力的计算复杂度优化
标准自注意力的计算和内存复杂度分析:
- 计算复杂度:O(n2⋅d)
- 内存复杂度:O(n2+n⋅d)
其中n是序列长度,d是隐藏维度。
线性注意力近似:
通过核函数近似,可以将复杂度降至线性:
Attention(Q,K,V)≈ϕ(Q)∑jϕ(Kj)ϕ(Q)(ϕ(K)TV)
其中ϕ:Rd→Rr是特征映射。
使用随机特征:
ϕ(x)=r1[cos(w1Tx),sin(w1Tx),...,cos(wrTx),sin(wrTx)]
复杂度降至O(n⋅r⋅d),其中r≪n。
稀疏注意力模式:
定义注意力掩码M∈{0,1}n×n,稀疏度ρ=∣∣M∣∣0/n2。
局部窗口注意力:
Mij=1[∣i−j∣≤w]
复杂度:O(n⋅w⋅d)
跨步注意力:
Mij=1[imods=0 or jmods=0]
复杂度:O(n2/s⋅d)
组合模式通过并集实现:
M=Mlocal∪Mstrided∪Mrandom
BigBird证明了这种组合可以保持完整注意力的表达能力。
评论(0)