大语言模型的核心算法——简要解析

举报
DuHz 发表于 2025/09/10 21:26:32 2025/09/10
【摘要】 大语言模型的核心算法——简要解析 Transformer架构的数学本质与演进 自注意力机制的核心原理Transformer架构的灵魂在于自注意力机制,它允许模型在处理序列中的每个元素时,动态地关注序列中的所有其他位置。从数学角度看,自注意力的计算过程可以表达为:Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{s...

大语言模型的核心算法——简要解析

Transformer架构的数学本质与演进

自注意力机制的核心原理

Transformer架构的灵魂在于自注意力机制,它允许模型在处理序列中的每个元素时,动态地关注序列中的所有其他位置。从数学角度看,自注意力的计算过程可以表达为:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

当输入序列XRn×dmodelX \in \mathbb{R}^{n \times d_{model}}进入自注意力层时,首先通过三个不同的线性变换生成查询(Query)、键(Key)和值(Value)矩阵:

Q=XWQ,K=XWK,V=XWVQ = XW^Q, \quad K = XW^K, \quad V = XW^V

其中WQ,WK,WVRdmodel×dkW^Q, W^K, W^V \in \mathbb{R}^{d_{model} \times d_k}是可学习的参数矩阵。查询矩阵QQ代表当前位置想要获取的信息,键矩阵KK代表每个位置能够提供的信息特征,而值矩阵VV则是实际要传递的信息内容。通过计算QQKK的点积,模型得到了一个注意力分数矩阵,表示每个位置对其他所有位置的关注程度。

这里的缩放因子dk\sqrt{d_k}起着关键作用。随着维度dkd_k的增加,点积的方差会线性增长,可能导致softmax函数进入饱和区域,梯度变得极小。具体来说,假设qqkk的分量是独立同分布的随机变量,均值为0,方差为1,则它们点积的方差为dkd_k。通过除以dk\sqrt{d_k}进行缩放,确保了注意力分数保持在合理的范围内,维持了训练的稳定性。

多头注意力(Multi-Head Attention)进一步扩展了这一机制。通过并行运行多个注意力头,每个头关注不同的表示子空间,模型能够同时捕获多种类型的依赖关系:

MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,...,\text{head}_h)W^O

其中每个注意力头计算为:

headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

这里WiQRdmodel×dkW_i^Q \in \mathbb{R}^{d_{model} \times d_k}WiKRdmodel×dkW_i^K \in \mathbb{R}^{d_{model} \times d_k}WiVRdmodel×dvW_i^V \in \mathbb{R}^{d_{model} \times d_v},而WORhdv×dmodelW^O \in \mathbb{R}^{hd_v \times d_{model}}将所有头的输出整合。这种设计让模型能够从多个角度理解输入序列的结构和语义。

位置编码的革命性演进

原始Transformer使用正弦位置编码来注入序列的顺序信息:

PE(pos,2i)=sin(pos100002i/dmodel)PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)

PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)

这种编码方式虽然简单有效,但在处理超长序列时存在局限性。

2024年,旋转位置嵌入(RoPE)已经成为事实标准,被LLaMA、GPT-NeoX等主流模型广泛采用。RoPE的核心创新在于通过旋转矩阵编码绝对位置信息,同时自然地融入相对位置依赖。对于位置mm的token embedding xm\mathbf{x}_m,RoPE应用旋转变换:

f(xm,m)=Rmxmf(\mathbf{x}_m, m) = \mathbf{R}_m \mathbf{x}_m

其中旋转矩阵Rm\mathbf{R}_m定义为:

Rm=(cosmθ0sinmθ000sinmθ0cosmθ00000cosmθ1sinmθ100sinmθ1cosmθ1)\mathbf{R}_m = \begin{pmatrix} \cos m\theta_0 & -\sin m\theta_0 & 0 & 0 & \cdots \\ \sin m\theta_0 & \cos m\theta_0 & 0 & 0 & \cdots \\ 0 & 0 & \cos m\theta_1 & -\sin m\theta_1 & \cdots \\ 0 & 0 & \sin m\theta_1 & \cos m\theta_1 & \cdots \\ \vdots & \vdots & \vdots & \vdots & \ddots \end{pmatrix}

频率参数θi=100002i/d\theta_i = 10000^{-2i/d}。RoPE的优雅之处在于,当计算两个位置mmnn的注意力分数时:

f(q,m),f(k,n)=qTRmTRnk=qTRnmk\langle f(\mathbf{q}, m), f(\mathbf{k}, n) \rangle = \mathbf{q}^T \mathbf{R}_m^T \mathbf{R}_n \mathbf{k} = \mathbf{q}^T \mathbf{R}_{n-m} \mathbf{k}

点积结果自然包含了相对位置信息(nm)(n-m)。这使得模型能够处理训练时未见过的序列长度,实现了真正的长度外推能力。

相比之下,ALiBi(Attention with Linear Biases)采用了更加直接的方法,通过在注意力分数中添加线性偏置来编码位置信息:

AttentionALiBi(Q,K,V)=softmax(QKTdkmij)V\text{Attention}_{ALiBi}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} - m \cdot |i-j|\right)V

其中mm是头特定的斜率参数,ij|i-j|是位置差的绝对值。

Flash Attention与硬件协同优化

2024年7月发布的Flash Attention v3代表了算法与硬件协同设计的巅峰。标准注意力计算的时间复杂度为O(N2d)O(N^2d),空间复杂度为O(N2)O(N^2),其中NN是序列长度,dd是特征维度。Flash Attention通过分块计算和内存优化,将空间复杂度降低到O(N)O(N)

其核心思想是将注意力矩阵分块计算,避免存储完整的N×NN \times N注意力矩阵。对于输入序列,Flash Attention将其分为大小为Br×BcB_r \times B_c的块,并使用以下更新规则:

mijnew=max(mijold,m~ij)m_{ij}^{new} = \max(m_{ij}^{old}, \tilde{m}_{ij})

lijnew=emijoldmijnewlijold+em~ijmijnewl~ijl_{ij}^{new} = e^{m_{ij}^{old} - m_{ij}^{new}} l_{ij}^{old} + e^{\tilde{m}_{ij} - m_{ij}^{new}} \tilde{l}_{ij}

Oijnew=lijoldlijnewemijoldmijnewOijold+l~ijlijnewem~ijmijnewO~ijO_{ij}^{new} = \frac{l_{ij}^{old}}{l_{ij}^{new}} e^{m_{ij}^{old} - m_{ij}^{new}} O_{ij}^{old} + \frac{\tilde{l}_{ij}}{l_{ij}^{new}} e^{\tilde{m}_{ij} - m_{ij}^{new}} \tilde{O}_{ij}

其中mm是最大值,ll是指数和,OO是输出。这种增量更新方式允许逐块处理,显著减少了内存访问。

在FP16精度下,Flash Attention v3达到了740 TFLOPs/s的性能,相当于H100理论峰值的75%利用率。它还支持FP8低精度计算,在保持数值稳定性的同时达到了1.2 PFLOPs/s的性能。

分组查询注意力(GQA)作为另一项重要创新,通过将查询头分组共享键值对,在效率与质量之间找到了完美平衡。假设有HH个查询头和GG个键值组(G<HG < H),则第ii个查询头使用第iG/H\lfloor i \cdot G/H \rfloor组的键值对:

GQAi=Attention(Qi,KiG/H,ViG/H)\text{GQA}_i = \text{Attention}(Q_i, K_{\lfloor i \cdot G/H \rfloor}, V_{\lfloor i \cdot G/H \rfloor})

以Llama 2 70B为例,使用8个组可以将KV缓存减少87.5%,而模型性能几乎不受影响。

主流模型的架构创新与技术突破

GPT-4的稀疏专家混合架构

根据已知信息,GPT-4采用了Mixture of Experts(MoE)架构,总参数约1.8万亿,包含16个专家网络,每个约1110亿参数。关键的是,每次前向传播仅激活2个专家,实际使用约2200亿参数。

MoE架构的核心在于路由机制。对于输入token xx,路由网络计算每个专家的选择概率:

G(x)=TopK(softmax(Wgx+ϵ),k)G(x) = \text{TopK}(\text{softmax}(W_g \cdot x + \epsilon), k)

其中WgW_g是路由权重,ϵ\epsilon是噪声项用于探索,TopK\text{TopK}选择概率最高的kk个专家。最终输出为:

y=i=1kGi(x)Ei(x)y = \sum_{i=1}^k G_i(x) \cdot E_i(x)

这里Ei(x)E_i(x)是第ii个专家的输出,Gi(x)G_i(x)是对应的门控权重。

负载均衡是MoE训练的关键挑战。DeepSeek-V3通过auxiliary-loss-free策略优化了这一过程,定义负载均衡损失为:

Lbalance=αi=1N(fi1N)2\mathcal{L}_{balance} = \alpha \cdot \sum_{i=1}^N \left( f_i - \frac{1}{N} \right)^2

其中fif_i是第ii个专家的平均激活频率,NN是专家总数,α\alpha是平衡系数。

Claude的Constitutional AI训练范式

Anthropic的Constitutional AI代表了一种全新的模型对齐方法。不同于传统的RLHF需要大量人工标注,Constitutional AI通过两个阶段实现自监督学习。

在第一阶段,模型生成响应后,基于预定义的宪法原则进行自我批评和修订。这个过程可以形式化为条件生成:$$p!\left(y_{\text{revised}} \mid x,, y_{\text{initial}},, \mathcal{C}\right)
= \prod_{t=1}^{T} p!\left(y_t \mid y_{<t},, x,, y_{\text{initial}},, \mathcal{C}\right)$$

其中 C\mathcal{C} 表示宪法原则集合,yinitialy_{\text{initial}} 是初始响应,yrevisedy_{\text{revised}} 是修订后的响应。

在第二阶段,使用AI反馈(RLAIF)训练偏好模型。偏好模型的目标是学习条件分布:

P(y1y2x,C)=σ(rθ(x,y1,C)rθ(x,y2,C))P(y_1 \succ y_2 | x, \mathcal{C}) = \sigma(r_\theta(x, y_1, \mathcal{C}) - r_\theta(x, y_2, \mathcal{C}))

其中rθr_\theta是参数化的奖励函数,σ\sigma是sigmoid函数,y1y2y_1 \succ y_2表示y1y_1优于y2y_2

LLaMA 3.1的密集架构优化

Meta的LLaMA 3.1 405B虽然采用密集架构而非MoE,但通过一系列技术创新实现了卓越性能。模型使用RMSNorm替代LayerNorm:

RMSNorm(x)=xRMS(x)γ\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma

其中:

RMS(x)=1di=1dxi2\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2}

这种简化不仅提高了训练稳定性,还减少了约7%的计算开销。

SwiGLU激活函数是另一项关键创新:

SwiGLU(x)=Swish(xW1)(xW2)\text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes (xW_2)

其中Swish函数定义为:

Swish(x)=xσ(βx)=x1+eβx\text{Swish}(x) = x \cdot \sigma(\beta x) = \frac{x}{1 + e^{-\beta x}}

β=1\beta = 1时,SwiGLU提供了比ReLU更平滑的梯度,其导数为:

dSwishdx=σ(βx)+βxσ(βx)(1σ(βx))\frac{d\text{Swish}}{dx} = \sigma(\beta x) + \beta x \cdot \sigma(\beta x)(1 - \sigma(\beta x))

这种平滑性有助于模型学习更复杂的表示,特别是在深层网络中。

训练技术的范式转变

从RLHF到DPO的演进

强化学习从人类反馈(RLHF)曾是大模型对齐的黄金标准。RLHF的目标是最大化期望奖励,同时限制与参考策略的偏离:

maxπθExD,yπθ(yx)[Rϕ(x,y)]βDKL[πθ(yx)πref(yx)]\max_{\pi_\theta} \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\theta(y|x)} [R_\phi(x,y)] - \beta \mathbb{D}_{KL}[\pi_\theta(y|x) || \pi_{ref}(y|x)]

通过拉格朗日乘数法,可以得到最优策略的闭式解:

π(yx)=1Z(x)πref(yx)exp(R(x,y)β)\pi^*(y|x) = \frac{1}{Z(x)} \pi_{ref}(y|x) \exp\left(\frac{R(x,y)}{\beta}\right)

其中Z(x)Z(x)是配分函数:

Z(x)=yπref(yx)exp(R(x,y)β)Z(x) = \sum_y \pi_{ref}(y|x) \exp\left(\frac{R(x,y)}{\beta}\right)

直接偏好优化(DPO)的关键洞察是,可以从最优策略的形式反推出奖励函数:

R(x,y)=βlogπ(yx)πref(yx)+βlogZ(x)R(x,y) = \beta \log \frac{\pi^*(y|x)}{\pi_{ref}(y|x)} + \beta \log Z(x)

基于Bradley-Terry模型,人类偏好的概率可以表示为:

P(ywylx)=σ(R(x,yw)R(x,yl))P(y_w \succ y_l | x) = \sigma(R(x, y_w) - R(x, y_l))

代入奖励函数的表达式并化简(注意Z(x)Z(x)项会抵消),得到DPO的损失函数:

LDPO(πθ;πref)=E(x,yw,yl)D[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x,y_w,y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right]

这种方法消除了奖励模型训练阶段,直接优化策略参数,计算效率提升11%,内存使用减少11%。

长上下文处理的技术突破

Ring Attention通过分布式计算突破了序列长度限制。假设有PP个设备,序列长度为NN,每个设备处理N/PN/P长度的子序列。注意力计算分解为:

Attentionglobal=p=1PLocalAttentionp+CrossAttentionp(p+1)modP\text{Attention}_{global} = \bigcup_{p=1}^P \text{LocalAttention}_p + \text{CrossAttention}_{p \rightarrow (p+1) \mod P}

在环形拓扑中,设备pp在第tt轮接收来自设备(pt)modP(p-t) \mod P的键值对,计算交叉注意力:

Ap,t=softmax(QpK(pt)modPTdk)V(pt)modPA_{p,t} = \text{softmax}\left(\frac{Q_p K_{(p-t) \mod P}^T}{\sqrt{d_k}}\right) V_{(p-t) \mod P}

最终通过PP轮通信完成全局注意力计算,通信复杂度为O(N)O(N)而非O(N2)O(N^2)

StreamingLLM基于"注意力汇聚"现象,保留初始ksinkk_{sink}个tokens作为锚点,配合大小为ww的滑动窗口:

KVCache={KV1:ksink}{KVtw+1:t}\text{KVCache} = \{\text{KV}_{1:k_{sink}}\} \cup \{\text{KV}_{t-w+1:t}\}

注意力计算时,初始tokens获得的注意力权重满足:

i=1ksinkαi0.20.3\sum_{i=1}^{k_{sink}} \alpha_i \approx 0.2 \sim 0.3

即使这些tokens在语义上并不重要。这种设计使得模型能够稳定处理400万tokens的序列,速度比基线方法快22.2倍。

附录:核心算法

A. 自注意力机制的梯度推导与优化

考虑自注意力的前向传播:

A=softmax(QKTdk),Y=AVA = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right), \quad Y = AV

我们需要推导损失函数L\mathcal{L}关于QQKKVV的梯度。首先定义中间变量:

S=QKTdkRn×nS = \frac{QK^T}{\sqrt{d_k}} \in \mathbb{R}^{n \times n}

对于softmax函数,其Jacobian矩阵为:

AijSik=Aij(δjkAik)\frac{\partial A_{ij}}{\partial S_{ik}} = A_{ij}(\delta_{jk} - A_{ik})

其中δjk\delta_{jk}是Kronecker delta函数。

通过链式法则,损失关于SS的梯度为:

LS=LYYAAS\frac{\partial \mathcal{L}}{\partial S} = \frac{\partial \mathcal{L}}{\partial Y} \cdot \frac{\partial Y}{\partial A} \cdot \frac{\partial A}{\partial S}

展开计算:

LSij=kLYikVjkAijk,lLYikVlkAilAij\frac{\partial \mathcal{L}}{\partial S_{ij}} = \sum_{k} \frac{\partial \mathcal{L}}{\partial Y_{ik}} V_{jk} A_{ij} - \sum_{k,l} \frac{\partial \mathcal{L}}{\partial Y_{ik}} V_{lk} A_{il} A_{ij}

简化为矩阵形式:

LS=A(LYVTdiag(A(LYVT)))\frac{\partial \mathcal{L}}{\partial S} = A \odot \left(\frac{\partial \mathcal{L}}{\partial Y} V^T - \text{diag}\left(A \cdot \left(\frac{\partial \mathcal{L}}{\partial Y} V^T\right)\right)\right)

进而得到关于QQKK的梯度:

LQ=1dkLSK\frac{\partial \mathcal{L}}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial S} K

LK=1dkLSTQ\frac{\partial \mathcal{L}}{\partial K} = \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial S}^T Q

值矩阵的梯度更直接:

LV=ATLY\frac{\partial \mathcal{L}}{\partial V} = A^T \frac{\partial \mathcal{L}}{\partial Y}

B. 旋转位置编码(RoPE)的数学性质

定理1:RoPE保持内积的相对位置依赖性。

证明:对于两个位置mmnn的向量q\mathbf{q}k\mathbf{k},应用RoPE后的内积为:

fq(q,m),fk(k,n)=qTRΘT(m)RΘ(n)k\langle f_q(\mathbf{q}, m), f_k(\mathbf{k}, n) \rangle = \mathbf{q}^T \mathbf{R}_\Theta^T(m) \mathbf{R}_\Theta(n) \mathbf{k}

由于旋转矩阵的性质RΘT(m)=RΘ(m)\mathbf{R}_\Theta^T(m) = \mathbf{R}_\Theta(-m),我们有:

RΘT(m)RΘ(n)=RΘ(nm)\mathbf{R}_\Theta^T(m) \mathbf{R}_\Theta(n) = \mathbf{R}_\Theta(n-m)

将向量q\mathbf{q}k\mathbf{k}分解为2维子空间:

q=[q0,q1,...,qd/21]T\mathbf{q} = [\mathbf{q}_0, \mathbf{q}_1, ..., \mathbf{q}_{d/2-1}]^T

对于每个2维子空间ii,内积计算为:

fq(qi,m),fk(ki,n)=qiT(cos((nm)θi)sin((nm)θi)sin((nm)θi)cos((nm)θi))ki\langle f_q(\mathbf{q}_i, m), f_k(\mathbf{k}_i, n) \rangle = \mathbf{q}_i^T \begin{pmatrix} \cos((n-m)\theta_i) & -\sin((n-m)\theta_i) \\ \sin((n-m)\theta_i) & \cos((n-m)\theta_i) \end{pmatrix} \mathbf{k}_i

展开得:

=qi,0ki,0cos((nm)θi)+qi,0ki,1sin((nm)θi)qi,1ki,0sin((nm)θi)+qi,1ki,1cos((nm)θi)= q_{i,0}k_{i,0}\cos((n-m)\theta_i) + q_{i,0}k_{i,1}\sin((n-m)\theta_i) - q_{i,1}k_{i,0}\sin((n-m)\theta_i) + q_{i,1}k_{i,1}\cos((n-m)\theta_i)

=(qi,0ki,0+qi,1ki,1)cos((nm)θi)+(qi,0ki,1qi,1ki,0)sin((nm)θi)= (q_{i,0}k_{i,0} + q_{i,1}k_{i,1})\cos((n-m)\theta_i) + (q_{i,0}k_{i,1} - q_{i,1}k_{i,0})\sin((n-m)\theta_i)

这表明内积仅依赖于相对位置(nm)(n-m),证毕。

C. DPO损失函数的变分推导

从RLHF的目标函数出发:

J(πθ)=ExD,yπθ[R(x,y)]βDKL[πθπref]J(\pi_\theta) = \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\theta} [R(x,y)] - \beta \mathbb{D}_{KL}[\pi_\theta || \pi_{ref}]

KL散度可以展开为:

DKL[πθπref]=Eyπθ[logπθ(yx)πref(yx)]\mathbb{D}_{KL}[\pi_\theta || \pi_{ref}] = \mathbb{E}_{y \sim \pi_\theta} \left[\log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}\right]

将目标函数重写为:

J(πθ)=Eyπθ[R(x,y)βlogπθ(yx)πref(yx)]J(\pi_\theta) = \mathbb{E}_{y \sim \pi_\theta} \left[R(x,y) - \beta \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}\right]

πθ\pi_\theta求变分导数并令其为零:

δJδπθ(yx)=R(x,y)βlogπθ(yx)πref(yx)βλ(x)=0\frac{\delta J}{\delta \pi_\theta(y|x)} = R(x,y) - \beta \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)} - \beta - \lambda(x) = 0

其中λ(x)\lambda(x)是保证yπθ(yx)=1\sum_y \pi_\theta(y|x) = 1的拉格朗日乘数。

解得最优策略:

π(yx)=πref(yx)exp(R(x,y)λ(x)β)\pi^*(y|x) = \pi_{ref}(y|x) \exp\left(\frac{R(x,y) - \lambda(x)}{\beta}\right)

归一化条件给出:

yπ(yx)=1eλ(x)/β=yπref(yx)exp(R(x,y)β)=Z(x)\sum_y \pi^*(y|x) = 1 \Rightarrow e^{\lambda(x)/\beta} = \sum_y \pi_{ref}(y|x) \exp\left(\frac{R(x,y)}{\beta}\right) = Z(x)

因此:

π(yx)=1Z(x)πref(yx)exp(R(x,y)β)\pi^*(y|x) = \frac{1}{Z(x)} \pi_{ref}(y|x) \exp\left(\frac{R(x,y)}{\beta}\right)

反解奖励函数:

R(x,y)=βlogπ(yx)πref(yx)+βlogZ(x)R(x,y) = \beta \log \frac{\pi^*(y|x)}{\pi_{ref}(y|x)} + \beta \log Z(x)

对于偏好对(yw,yl)(y_w, y_l),Bradley-Terry模型给出:

P(ywyl)=exp(R(x,yw))exp(R(x,yw))+exp(R(x,yl))=σ(R(x,yw)R(x,yl))P(y_w \succ y_l) = \frac{\exp(R(x,y_w))}{\exp(R(x,y_w)) + \exp(R(x,y_l))} = \sigma(R(x,y_w) - R(x,y_l))

代入奖励函数表达式(Z(x)Z(x)项抵消):

P(ywyl)=σ(βlogπ(ywx)πref(ywx)βlogπ(ylx)πref(ylx))P(y_w \succ y_l) = \sigma\left(\beta \log \frac{\pi^*(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi^*(y_l|x)}{\pi_{ref}(y_l|x)}\right)

最大似然估计给出DPO损失函数:

LDPO=E[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{DPO} = -\mathbb{E} \left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)}\right)\right]

D. Mixture of Experts的负载均衡优化

考虑MoE层的前向传播,输入xRdx \in \mathbb{R}^d,有NN个专家{Ei}i=1N\{E_i\}_{i=1}^N,门控网络GG

标准MoE的输出为:

y=i=1NGi(x)Ei(x)y = \sum_{i=1}^N G_i(x) E_i(x)

其中门控权重通过softmax计算:

G(x)=softmax(Wgx+bg)G(x) = \text{softmax}(W_g x + b_g)

为了实现稀疏激活,使用Top-K操作:

G~(x)=KeepTopK(G(x),k)\tilde{G}(x) = \text{KeepTopK}(G(x), k)

负载均衡的目标是使每个专家的期望负载相等。定义专家ii的负载为:

Li=ExD[1[iTopK(G(x),k)]]L_i = \mathbb{E}_{x \sim \mathcal{D}} [\mathbb{1}[i \in \text{TopK}(G(x), k)]]

理想情况下,Li=k/NL_i = k/N 对所有ii成立。

引入辅助损失促进均衡:

Laux=αNi=1NfiPi\mathcal{L}_{aux} = \alpha \cdot N \cdot \sum_{i=1}^N f_i \cdot P_i

其中fif_i是批次中路由到专家ii的token比例:

fi=1BxB1[iTopK(G(x),k)]f_i = \frac{1}{|B|} \sum_{x \in B} \mathbb{1}[i \in \text{TopK}(G(x), k)]

PiP_i是专家ii的平均路由概率:

Pi=1BxBGi(x)P_i = \frac{1}{|B|} \sum_{x \in B} G_i(x)

这个损失函数鼓励fif_iPiP_i都接近1/N1/N,从而实现负载均衡。

E. Flash Attention的数值稳定性分析

Flash Attention使用在线softmax算法避免数值溢出。对于输入序列分块Qi,Kj,VjQ_i, K_j, V_j,计算过程如下:

定义局部最大值和指数和:

mij=max(QiKjT/dk)m_{ij} = \max(Q_i K_j^T / \sqrt{d_k})

ij=exp(QiKjT/dkmij)\ell_{ij} = \sum \exp(Q_i K_j^T / \sqrt{d_k} - m_{ij})

增量更新规则保证数值稳定性:

minew=max(miold,mij)m_i^{new} = \max(m_i^{old}, m_{ij})

inew=emioldminewiold+emijminewij\ell_i^{new} = e^{m_i^{old} - m_i^{new}} \ell_i^{old} + e^{m_{ij} - m_i^{new}} \ell_{ij}

Oinew=emioldminewioldinewOiold+emijminewijinewO~ijO_i^{new} = \frac{e^{m_i^{old} - m_i^{new}} \ell_i^{old}}{\ell_i^{new}} O_i^{old} + \frac{e^{m_{ij} - m_i^{new}} \ell_{ij}}{\ell_i^{new}} \tilde{O}_{ij}

定理2:Flash Attention的输出与标准注意力在数值上等价。

证明:考虑完整的注意力计算:

A=softmax(S)=exp(S)jexp(Sj)A = \text{softmax}(S) = \frac{\exp(S)}{\sum_j \exp(S_j)}

使用log-sum-exp技巧:

logjexp(Sj)=m+logjexp(Sjm)\log \sum_j \exp(S_j) = m + \log \sum_j \exp(S_j - m)

其中m=maxjSjm = \max_j S_j。Flash Attention通过分块计算维护:

  1. 全局最大值:mglobal=maxi,jmijm_{global} = \max_{i,j} m_{ij}
  2. 全局指数和:global=i,jexp(mijmglobal)ij\ell_{global} = \sum_{i,j} \exp(m_{ij} - m_{global}) \ell_{ij}
  3. 加权输出:Oglobal=1globali,jexp(mijmglobal)ijOijO_{global} = \frac{1}{\ell_{global}} \sum_{i,j} \exp(m_{ij} - m_{global}) \ell_{ij} O_{ij}

通过归纳法可证明这与标准计算等价,且避免了指数溢出。

F. 梯度累积与混合精度训练的理论分析

在大模型训练中,由于内存限制,常使用梯度累积技术。设批量大小为BB,累积步数为KK,则有效批量大小为Beff=KBB_{eff} = K \cdot B

标准SGD更新:

θt+1=θtηθL(θt;Beff)\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t; \mathcal{B}_{eff})

梯度累积更新:

gk=1KθL(θt;Bk),k=1,...,Kg_k = \frac{1}{K} \nabla_\theta \mathcal{L}(\theta_t; \mathcal{B}_k), \quad k = 1, ..., K

θt+1=θtηk=1Kgk\theta_{t+1} = \theta_t - \eta \sum_{k=1}^K g_k

定理3:在凸优化条件下,梯度累积与标准SGD收敛速度相同。

证明:假设损失函数L\mathcal{L}LL-光滑的,即:

L(θ1)L(θ2)Lθ1θ2||\nabla \mathcal{L}(\theta_1) - \nabla \mathcal{L}(\theta_2)|| \leq L ||\theta_1 - \theta_2||

对于梯度累积,期望梯度为:

E[k=1Kgk]=E[θL(θt;Beff)]\mathbb{E}[\sum_{k=1}^K g_k] = \mathbb{E}[\nabla_\theta \mathcal{L}(\theta_t; \mathcal{B}_{eff})]

方差分析:

Var[k=1Kgk]=1KVar[θL(θt;B)]\text{Var}[\sum_{k=1}^K g_k] = \frac{1}{K} \text{Var}[\nabla_\theta \mathcal{L}(\theta_t; \mathcal{B})]

这与使用大批量BeffB_{eff}的方差相同,因此收敛速度一致。

混合精度训练使用FP16计算和FP32主权重。损失缩放防止梯度下溢:

Lscaled=sL\mathcal{L}_{scaled} = s \cdot \mathcal{L}

gFP16=θLscaled=sθLg_{FP16} = \nabla_\theta \mathcal{L}_{scaled} = s \cdot \nabla_\theta \mathcal{L}

更新规则:

θFP32=θFP32ηgFP16s\theta_{FP32} = \theta_{FP32} - \eta \cdot \frac{g_{FP16}}{s}

动态损失缩放通过监控梯度溢出自适应调整ss

st+1={stρupif no overflow for N stepsst/ρdownif overflow detecteds_{t+1} = \begin{cases} s_t \cdot \rho_{up} & \text{if no overflow for } N \text{ steps} \\ s_t / \rho_{down} & \text{if overflow detected} \end{cases}

典型参数:ρup=2\rho_{up} = 2ρdown=2\rho_{down} = 2N=2000N = 2000

G. 量化技术的信息论分析

权重量化可以从率失真理论角度分析。设原始权重WRm×nW \in \mathbb{R}^{m \times n},量化后权重W^\hat{W},量化函数QQ

均方量化误差:

D=E[WW^F2]=i,j(WijW^ij)2D = \mathbb{E}[||W - \hat{W}||_F^2] = \sum_{i,j} (W_{ij} - \hat{W}_{ij})^2

对于kk-bit均匀量化:

W^ij=sround(Wijs)\hat{W}_{ij} = s \cdot \text{round}\left(\frac{W_{ij}}{s}\right)

其中量化步长:

s=max(W)min(W)2k1s = \frac{\max(W) - \min(W)}{2^k - 1}

量化信噪比(SQNR):

SQNR=10log10(E[W2]E[(WW^)2])6.02k+4.7720log10(σWμW)\text{SQNR} = 10 \log_{10} \left(\frac{\mathbb{E}[W^2]}{\mathbb{E}[(W - \hat{W})^2]}\right) \approx 6.02k + 4.77 - 20\log_{10}\left(\frac{\sigma_W}{\mu_W}\right)

AWQ(Activation-aware Weight Quantization)通过重要性加权优化量化:

minW^xDf(x;W)f(x;W^)2\min_{\hat{W}} \sum_{x \in \mathcal{D}} ||f(x; W) - f(x; \hat{W})||^2

定义权重重要性:

Iij=ExD[f(x;W)Wij]I_{ij} = \mathbb{E}_{x \sim \mathcal{D}} \left[\left|\frac{\partial f(x; W)}{\partial W_{ij}}\right|\right]

AWQ保持top-p%p\%重要权重为全精度:

W^ij={Wijif Iij>τpQk(Wij)otherwise\hat{W}_{ij} = \begin{cases} W_{ij} & \text{if } I_{ij} > \tau_p \\ Q_k(W_{ij}) & \text{otherwise} \end{cases}

实验表明,保持0.1%的关键权重可以将量化损失减少75%。

H. 长序列注意力的计算复杂度优化

标准自注意力的计算和内存复杂度分析:

  • 计算复杂度:O(n2d)O(n^2 \cdot d)
  • 内存复杂度:O(n2+nd)O(n^2 + n \cdot d)

其中nn是序列长度,dd是隐藏维度。

线性注意力近似

通过核函数近似,可以将复杂度降至线性:

Attention(Q,K,V)ϕ(Q)(ϕ(K)TV)ϕ(Q)jϕ(Kj)\text{Attention}(Q,K,V) \approx \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)\sum_j \phi(K_j)}

其中ϕ:RdRr\phi: \mathbb{R}^d \rightarrow \mathbb{R}^r是特征映射。

使用随机特征:

ϕ(x)=1r[cos(w1Tx),sin(w1Tx),...,cos(wrTx),sin(wrTx)]\phi(x) = \frac{1}{\sqrt{r}} [\cos(w_1^T x), \sin(w_1^T x), ..., \cos(w_r^T x), \sin(w_r^T x)]

复杂度降至O(nrd)O(n \cdot r \cdot d),其中rnr \ll n

稀疏注意力模式

定义注意力掩码M{0,1}n×nM \in \{0,1\}^{n \times n},稀疏度ρ=M0/n2\rho = ||M||_0 / n^2

局部窗口注意力:

Mij=1[ijw]M_{ij} = \mathbb{1}[|i - j| \leq w]

复杂度:O(nwd)O(n \cdot w \cdot d)

跨步注意力:

Mij=1[imods=0 or jmods=0]M_{ij} = \mathbb{1}[i \mod s = 0 \text{ or } j \mod s = 0]

复杂度:O(n2/sd)O(n^2 / s \cdot d)

组合模式通过并集实现:

M=MlocalMstridedMrandomM = M_{local} \cup M_{strided} \cup M_{random}

BigBird证明了这种组合可以保持完整注意力的表达能力。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。