大模型原理--参数高效微调(LoRA篇)
1.概述
LoRA(Low-Rank Adaptation)是一种高效且广泛使用的参数高效微调方法,由微软研究院于 2021 年提出。因其训练成本低、适配能力强、推理无额外开销等优势,已成为当前大语言模型监督微调(SFT)中最广泛使用的技术。

2.数学原理
2.1 矩阵基础
(1)矩阵的秩(Rank):矩阵中列向量(或行向量)线性独立的最大个数。
(2)低秩矩阵:设矩阵W∈Rd×k,若W的秩满足rank(W)<<min(d,k),则称W为低秩矩阵。
(3)低秩分解:设矩阵W∈Rd×k,若W的秩为r,则存在:A∈Rd×r,B∈Rr×k,使得W=AB。
2.2 LoRA 基于实践观察:在全量微调(Full Fine-tuning)过程中,权重增量形成的矩阵是低秩矩阵。
W = W0+∆W,其中∆W是低秩矩阵
根据低秩分解,所以W = W0+AB。但是在实践过程中,一般∆W的秩r取 4、8 或 16等用于近似分解,所以W≈W0+AB。
3. 在训练过程中,LoRA 完全冻结原始权重W0,仅对新增的低秩矩阵A和B进行梯度更新。大幅减少了需要更新的参数量,同时也避免了对大规模模型权重的直接修改,使微调过程更加轻量、高效。
4. 在推理阶段,∆W可以直接合并回W0,简单、高效。
5. 代码实践
5.1 插入位置
LoRA 通常插入的位置是模型的线性层。最常用的是对注意力层的 q_proj 和 v_proj 插入 LoRA,原因如下:①Query 和 Value 对任务语义最敏感;②仅插这两处即可接近全参微调性能;③参数和显存开销最小。如需更强表达能力,也可扩展至 k_proj、o_proj 或 FFN 层,但会增加成本。
5.2 在工程中实现中,通常会额外加入两个关键组件:
(1)缩放系数(α)
为了控制 LoRA 增量在训练初期的影响力,并在不同秩下保持数值稳定性,通常会在增量上加入缩放系数,使前向计算变为:W = W0+(α/r)BA
(2) LoRA Dropout
为提升泛化能力,减轻小数据集上的过拟合,通常会对LoRA Layer的输入进行dropout。
5.3 代码实现
import math
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
def __init__(self,original_linear: nn.Linear,r: int = 8,alpha: float = 16.0,dropout: float = 0.0):
"""
LoRA 包装器,作用于 nn.Linear 层
:param original_linear: 预训练好的 nn.Linear
:param r: 低秩分解的秩
:param alpha: 缩放因子
:param dropout: 可选 dropout,放在 BA 之后(常用 0.0)
"""
super().__init__()
self.r = r
self.alpha = alpha
# 保留原层,但冻结它的权重
self.original_linear = original_linear
for param in self.original_linear.parameters():
param.requires_grad = False
out_features, in_features = original_linear.weight.shape
# A: (in_features, r) ,用 Kaiming 初始化
self.A = nn.Parameter(torch.zeros(in_features, r))
nn.init.kaiming_normal_(self.A, a=math.sqrt(5))
# B: (r, out_features) ,初始化为零,保证微调前和原模型输出结果相同
self.B = nn.Parameter(torch.zeros(r, out_features))
# 可选 dropout
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
# 记录缩放因子
self.scaling = self.alpha / self.r
def forward(self, x):
# 原始输出
orig_out = self.original_linear(x)
# LoRA 修改量:先 A 后 B,然后缩放
lora_out = (x @ self.A) @ self.B # (batch, ..., in) -> (..., r) -> (..., out)
lora_out = self.scaling * lora_out
lora_out = self.dropout(lora_out)
return orig_out + lora_out
def merge_weights(self):
"""将 LoRA 旁路融合进原权重(用于推理加速)"""
device = self.A.device
# 计算融合后的权重
delta_W = (self.A @ self.B.T) * self.scaling # (in, out)
self.original_linear.weight.data += delta_W.T # nn.Linear.weight: (out, in)
# 融合后可选择清除 A, B 以节省显存(可选)
# self.A = None; self.B = None
- 点赞
- 收藏
- 关注作者
评论(0)