大模型原理--混合精度计算
1.概述
大模型混合精度训练,本质上是一种用精度换空间,再用技巧保精度的策略:通过在模型训练的不同环节,灵活使用不同精度的浮点数(如 FP32、FP16、BF16 等),在保证模型最终性能的前提下,大幅提升训练速度、降低显存占用。
2.混合精度训练主要有三个关键机制
①主权重副本 (Master Weights)
框架会一直维护一份 FP32 高精度的模型权重。在每轮训练迭代中,会将这份权重复制一份转换为 FP16 去做高速的前向和反向传播计算。当得到 FP16 的梯度后,会将其转换回 FP32,并用来更新那份高精度的主权重。这确保了参数更新足够精细,不会丢失信息。
②自动损失缩放 (Loss Scaling)
梯度下溢:在训练后期,很多参数的梯度值会变得非常小。如果这个小梯度的值小于FP16所能表示的最小值,在FP16格式下它就会被直接表示为0。这被称为梯度下溢。一旦梯度变为0,对应的参数就永远无法再更新了,这会严重影响模型收敛和最终效果。解决梯度下溢的方案非常简单:先将计算出的损失值乘以一个大的系数(如 1024),反向传播时梯度也跟着被放大,这样就不会小到变为 0。等优化器要更新参数时,再将梯度除以相同的系数还原回去即可。
动态Loss Scaling:固定一个缩放因子不够灵活,可能在某些阶段会不够大。现在普遍使用动态损失缩放。先将计算出的损失值乘以一个大的系数(如 1024),检查梯度,如果没有发生“上溢”,就尝试增大缩放因子(如乘以2),让保护更充分。如果检测到上溢,说明这次缩放过头了,就会跳过本次更新,并减小缩放因子(如除以2)。
③算子黑白名单:不是所有计算都适合低精度。像矩阵乘法这类计算密集型算子能用低精度大幅加速;而像 Softmax、归一化(Normalization)、Loss 计算这类对精度敏感的算子,则必须留在高精度下执行。框架会自动按“黑白名单”来管理这些计算。
3. 总结
传统训练使用 FP32 精度,但它的计算量和内存需求极大。混合精度训练的核心就是使用 FP16 或 BF16 这类低精度格式进行计算,它们能带来最直接的好处:计算更快和显存占用近乎减半。“混合”精度的精髓在于:用高精度(FP32)保证关键信息的准确性,用低精度(FP16/BF16)加速大部分运算和节省显存。
- 点赞
- 收藏
- 关注作者
评论(0)