大模型原理--分布式训练策略之ZeRO

举报
剑指南天 发表于 2026/05/06 20:41:40 2026/05/06
【摘要】 在标准数据并行中,每个设备都保存完整的模型状态,导致跨设备的冗余存储,并严重限制了可训练模型的规模(模型的规模受到显卡的限制,无法训练大规模参数的模型)。ZeRO的主要是为解决在标准数据并行中跨设备的冗余存储的问题。

1.概述

GPT-2 揭示了规模效应(Scaling Effect):随着模型参数与数据规模的增长,模型性能在多种任务上持续提升。基于此,大语言模型(Large Language Model, LLM)的参数量和训练数据规模必然持续增长,分布式训练已成为现代大模型训练体系中的核心技术。

本文主要介绍一个对传统数据并行深度增强的分布式训练策略:ZeRO(Zero Redundancy Optimizer)。

在标准数据并行中,每个设备都保存完整的模型状态,导致跨设备的冗余存储,并严重限制了可训练模型的规模(模型的规模受到显卡的限制,无法训练大规模参数的模型)。ZeRO的主要是为解决在标准数据并行中跨设备的冗余存储的问题。

2. ZeRO

ZeRO的核心思想是:分片存储、按需加载。它不再让每块 GPU 都完整保存一整套模型状态,而是将模型状态切分成多个分片,分布式地存储在不同的 GPU 上;在执行某一层的前向、反向传播或参数更新时,再通过高效的集体通信,按需、临时地从其他设备拉取当前计算所需的分片参数,并在用完后立即释放。

(1)分片的策略

训练过程中,需要保存和维护的模型状态主要由以下三部分组成:模型参数,梯度,优化器状态。其中优化器状态显存占用大于梯度显存占用,模型参数通常与模型梯度显存占用相当

针对上述部分模型状态,ZeRO 将分片策略划分为三个逐级递进的阶段:①ZeRO-1:优化器状态分片;②ZeRO-2:优化器状态 + 梯度分片;③ZeRO-3:优化器状态 + 梯度分片 + 模型参数分片。

(2)具体流程

假设模型结构是16层的transformer block模型,分片策略选择ZeRO-3,如下:

①第一步和数据并行策略一样,让不同 GPU 并行处理不同的数据子集。

②将数据并行策略的模型状态进行切分,分配给各个GPU。

切分后的模型状态分布情况如下,每个GPU保存一部分模型状态。

③前向传播时,现将GPU 0上的M0多层的模型参数(FP16)广播到其他GPU。

并行计算每个GPU上M0多层的激活值(红框内,FP16),然后销毁已经广播的模型参数(FP16)。

然后M1,M2,M3多层依次执行上述相同的操作,得到每个GPU上面的模型的所有激活值(红框内,FP16)和Loss。因为接下来要反向传播,所以保留M3多层在各个GPU上的模型参数(FP16)。

④反向传播时,并行计算M3多层在各个GPU上面的梯度(FP16)。

将其余GPU上面的M3多层梯度数据(FP16)发送到GPU 3后,销毁其余GPU M3多层的梯度(FP16)、模型参数(FP16)和激活值(FP16),GPU 3计算梯度均值(FP16),保存。下图红框内就是保存的M3多层的梯度数据(FP16)。

然后M2,M1,M0依次执行上述相同的操作(注意:需要先广播对应的的模型参数(FP16)),计算出各个GPU对应层段的平均梯度(FP16)。

⑤参数更新时,每个GPU现将各个多层的梯度数据(FP16)复制成一份(FP32),保存。计算Variance(历史梯度,FP32)和Momentum(动量FP32),保存。然后计算新的模型参数(FP32),保存。

每个GPU把各个多层的新的模型参数复制一份(FP16),保存。各个GPU模型参数的更新结束,完成一个step。

3. 总结:将模型状态分配给各个GPU,解除了跨设备存储冗余,显著减少了每块 GPU 的显存占用,所以训练的模型规模远超单卡显存限制。又可以通过配置分片策略,平衡显存与通信开销。所以ZeRO是训练超大规模模型不可或缺的基础技术。

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。