深度学习的分布式训练与集合通信(三)
【摘要】 本文将会介绍一些更高阶的并行方式,如序列并行(SP),上下文并行(CP),混合序列并行Ulysess,ZeRO系列并行优化策略,完全分片数据并行(FSDP)。并且,在文章最后将汇总所有介绍过的并行方案与它们的通信模式,帮助读者初步建立起分布式训练与集合通信的知识结构体系。
本专题介绍常见的深度学习分布式训练的并行策略和背后使用到的集合通信操作,希望能帮助读者理解分布式训练的原理,以及集合通信之于分布式训练的重要性和必要性。鉴于篇幅限制,将拆分成三个部分展开讲述。
在上两回中,我们介绍了DP,PP,TP,EP等多种并行策略及其通信模式,详情请参见深度学习的分布式训练与集合通信(一)和深度学习的分布式训练与集合通信(二)。第三部分将会介绍一些更高阶的并行方式,如序列并行(SP),上下文并行(CP),混合序列并行Ulysess,ZeRO系列并行优化策略,完全分片数据并行(FSDP)。并且,在文章最后将汇总所有介绍过的并行方案与它们的通信模式,帮助读者初步建立起分布式训练与集合通信的知识结构体系。
序列并行/上下文并行/混合序列并行
序列并行可以看作是先前介绍过的数据并行在Transformer大语言模型下的延伸和扩展,它是更加细粒度的数据并行。大家知道,Transformer语言类模型所处理的数据对象是序列(Sequence),一个batch中包含多个序列,一个序列中包含多个token,每个token由一个向量表示。序列并行(Sequence Parallelism, SP)即为将一个序列切开,分片段到多个节点上并行处理的策略,其主要目的是可以摆脱单卡存储限制,训练超长上下文的大模型。
根据现有文献描述,序列并行可以被部署在Transformer模型的两个的阶段中,
- 一个是Attention阶段,最初来自Colossal-AI,Megatron-LM后来也提出了相似概念,将之称为上下文并行(Context Parallelism,CP),下文我们将之统称为Attention阶段的序列并行;除此之外,还有DeepSpeed提出的混合序列并行Ulysess,也在Attention阶段部分使用到了序列并行;
- 另一个是LayerNorm与Dropout阶段,来自Megatron-LM。
虽然都是序列并行,但由于处于不同的运算阶段,它们的行为以及影响是不同的。前者的目的主要是减少数据存储压力,打破模型输入序列长度(sequence length)的限制;而后者则是为了与LayerNorm与Dropout前后相邻阶段的张量并行(TP)搭配使用,减少存储压力。
下面详细介绍这两种序列并行及其通信模式。
1)Attention阶段的序列并行
输入token矩阵I中的每一列代表一个token,如图中红色箭头所示。所以序列并行也可以看作是矩阵I的列并行——将矩阵I中不同token的列向量拆分开并行计算,以减少单节点上针对序列长度的存储压力。
下面我们来看,将输入输出矩阵I、O按照列(token)的维度拆分开放在两个节点上并行计算(即序列并行SP2)对以上这5个矩阵乘法造成的影响。
总结一下就是,Attention阶段的序列并行,在模型训练的前向传播中,主要涉及到的集合通信操作有二,一是计算注意力矩阵A′时对于所有序列并行节点上矩阵K的AllGather操作,二是计算输出矩阵O时对于所有序列并行节点上矩阵V的AllGather操作,如下图所示。值得提及,对于Causal Attention,不是所有的K和V的分块都需要被每个节点用到,这使得通信不均衡效率低。解决方案一般是将token做互补分组,这里不做细节展开,感兴趣读者可以参见论文Striped Attention: Faster Ring Attention for Causal Transformers。
前向传播中K与V的AllGather,对应反向传播中K与V梯度的ReduceScatter,其原因可以用我们在系列推文(二)中的列并行级联原理来解释。
以K为例,Attention阶段的SP中涉及到K的两次矩阵乘法(2)(4),其实可以看作是两级列并行的级联:第一级是矩阵I的列并行,第二级是矩阵Q的列并行。列并行级联之间的通信模式为前向AllGather,反向ReduceScatter,如下图。V同理。
前向传播AllGather(图中的MM指代矩阵乘法):
反向传播ReduceScatter:
以上就是Attention阶段序列并行的基本原理了,以下再做四点补充:
- SP与TP的对比
- Ring Self-Attention
- Megatron-LM的 CP与Colossal-AI的SP的差别
- DeepSpeed的Ulysess方案
第一,Attention阶段的序列并行与TP的对比:Attention阶段除了SP(属于数据并行),还有一种通用的并行方式是TP(属于模型并行),如下图。TP将Attention阶段的计算按照Multi-Head中Head的维度分开,涉及到的通信主要是正向一次的各Head输出矩阵的AllReduce和反向一次的输入矩阵梯度的AllReduce;而SP是按照输入数据的序列长度的维度分的,涉及到的通信主要是正向一次的矩阵的AllGather与反向一次的矩阵梯度的ReduceScatter。值得提及,SP在Multi-Head与Single-Head情况下的通信行为一致,由于Head没有被分开,不涉及跨Head的通信行为。
图片来源:论文Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
第二,Attention阶段的序列并行的Ring Self-Attention:Ring Self-Attention是一种Attention阶段SP的常见实现方法。最早由Colossal-AI提出,如下图。它的本质思想是将Ring AllGather的通信步骤拆分开,用P2P的Send和Receive替换,与计算任务做流水掩盖。具体来说,即在分步完成K(V)矩阵AllGather操作的时候,同步进行注意力矩阵A(输出矩阵O)的计算。当计算一个A/O分块的时间大于传输一个K/V分块的时间,就不会引入额外的通信开销。
图片来源:论文Sequence Parallelism: Long Sequence Training from System Perspective
第三,Megatron-LM提出的 CP与Colossal-AI提出的SP的差别:CP在SP之后被提出,其主体思想与Colossal-AI的SP一致,都是Attention阶段的序列并行。其主要优化点在于,CP进一步利用了Flash Attention的方法对注意力矩阵进行了分块计算。结合上述Ring Attention的计算通信流水部署方案,CP一次传输一组KV矩阵,得到一个分块的输出矩阵O,最后再整合,降低存储与通信开销。
图片来源:论文Context Parallelism for Scalable Million-Token Inference
具体说来,之前介绍的SP在计算注意力矩阵A′时,需要按列的维度进行softmax计算,这也就要求只有当所有(列)的K被收集到以后(AllGather完成),单个节点上才能完成注意力矩阵从A到A′的softmax计算,所以KV矩阵的传输是分成两个阶段进行的;而CP用了Flash Attention则可以移除这个限制,分块计算注意力矩阵A′,进而得到输出矩阵分块O最后再整合,其KV矩阵的传输是一个阶段完成的,减低了通信开销,并且大幅降低了访存。
除此之外,CP还对大模型中常见的causal attention中序列不同位置的序列段负载不均衡的问题进行了优化。CP在将序列分成多份的同时,会考虑到不同序列段计算量的不一样,进行对称互补组队,再按组分配到不同节点中并行处理,这样可以提升硬件利用率,对计算与通信做更好的流水掩盖。
第四,除了SP与CP,Attention阶段的序列并行还有DeepSpeed提出的Ulysess方案,如下图。Ulysess将Attention阶段的计算分成了两个阶段:第一阶段计算QKV矩阵,与SP/CP相同,都是按照序列长度的维度进行切分;第二阶段计算AA′O矩阵,与SP/CP不同,Ulysess按照Multi-Head中Head的维度进行了切分,每个节点计算出全部序列的部分Head的结果,最后再整合结果回到序列切分的维度上去。
图片来源:论文Deepspeed Ulysses: System Optimizations For Enabling Training Of Extreme Long Sequence Transformer Models
所以,Ulysess需要在第二阶段开始前,对不同节点上的计算得到的QKV三矩阵(分属不同序列段,所有Head)依据第二阶段各节点负责的Head进行三次(QKV各一次)AlltoAll的通信操作,下图示意了矩阵Q的AlltoAll操作;以及在二阶段结束后,对不同节点上计算得到的O矩阵(分属不同Head,所有序列段)依据各节点负责的序列段再进行一次AlltoAll的通信操作。Ulysess这里的AlltoAll是负载均衡的(所有节点收发相同的数据量),不像之前介绍过的EP可能出现专家负载不均,从而导致AlltoAll劣化为AlltoAllV。
值得注意的是,使用Ulysess有一个前提,就是Multi-Head的Head个数需要能够整除序列并行数。
综上所有讨论,Attention的序列并行涉及到的主要通信模式有两种,一种是SP与CP的AllGather + ReduceScatter,采用Ring Attention的方法将通信分步骤打散成多个P2P的Send和Receive操作来实现,与分块计算进行流水掩盖;另一种是Ulysess的AlltoAll,使用在Attention一阶段序列并行与二阶段Multi-Head并行(可以看做TP)级联的交接处,以及二阶段的结束处。
2)LayerNorm与Dropout阶段的序列并行
Megatron-LM提出的SP,特指LayerNorm与Dropout阶段的序列并行,如下图,是一种与Transformer模型中其他阶段部署的TP结合使用的并行策略,简称SP+TP。其达到的主要效果是——在不增加通信量的同时(一次AllReduce等价为一次ReduceScatter+AllGather),减少各节点对中间激活值的存储需求以及降低LayerNorm与Dropout的计算量。
图片来源:论文Reducing Activation Recomputation in Large Transformer Models
Megatron-LM的SP的理论基础是:LayerNorm与Dropout阶段的计算对于序列长度的维度是独立的,即1)LayerNorm是对每个token的特征向量进行归一化,而不是对整个序列的所有token一起归一化;2)Dropout是以元素级别的方式进行的,即对每个token向量的每个元素独立地应用Dropout,不会跨token共享Dropout掩码,而是为每个元素独立生成掩码。
于是顺理成章的,LayerNorm与Dropout阶段不需要所有节点都包含完整序列的信息,可以把序列切成多份放在不同的节点上做LayerNorm与Dropout的计算,这样既降低了计算量,也减轻了节点的存储压力。
对于通信,SP+TP与之前纯TP相比通信总量维持不变,只是通信的模式发生了改变。具体来说,就是对于一个Transformer层,纯TP需要前反向各做两次AllReduce,而SP+TP需要前反向各做两次AllGather+ReduceScatter。解释一下这背后原因,我们知道Attention/FFN阶段的TP使用的都是列并行+行并行的级联形式,如下图。
图片来源:论文Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
列+行并行的数据流动示意图如下,详细推导见本系列的第二篇推文。
这种级联形式的主要特点就是:列并行的输入需要完整copy输入矩阵到每个节点;行并行后各节点的输出只是partial sum,要得到完整输出,需要将各节点的输出进行element-wise的求和;列行并行级联的中间不涉及通信操作。结合这些特点与LayerNorm/Dropout阶段SP的对数据的切分方式,可以知道在Attention/FFN阶段之前需要插入AllGather来收集所有序列片段的数据,而在Attention/FFN阶段之后只需要插入ReduceScatter来分片求和输出的partial sum,以在不同节点上获得不同序列片段对应的完整输出。
上面介绍的是前向传播过程,因为一层transformer layer中涉及两个列+行并行级联模块(Attention和FFN),在每个模块之前都需要插入AllGather,在每个模块之后都需要插入ReduceScatter,所有SP+TP一共是两次AllGather+ReduceScatter。
对于反向传播,我们知道AllGather和ReduceScatter在前反向传播上是对偶关系——前向AllGather对应反向ReduceScatter,前向ReduceScatter对应反向AllGather,所以SP+TP在反向传播时同样也需要做两次AllGather+ReduceScatter。
至此,序列并行就介绍完了,下表总结了典型三种序列并行的通信模式,所有序列并行的通信模式在HCCL API中均已支持。
ZeRO系列并行/完全分片数据并行
ZeRO的全称是Zero Redundancy Optimizer零冗余优化,由DeepSpeed提出,用来解决大模型训练中内存开销大的问题,其主旨思想是去除冗余存储,以及用通信换存储。以普通数据并行(DP)为基线(对应下图中的第一行),ZeRO按存储需求由多到少,衍生出了三种进阶版的并行策略,简称ZeRO-1,ZeRO-2,ZeRO-3(对应下图中的第二三四行)。
完全分片数据并行(Fully Sharded Data Parallel, FSDP)是pytorch1.11中的特性,其本质就是ZeRO-3。下面我们逐一来看。
在基线DP中(图中的第一行),每个节点都会维护一套完整的模型。这包含模型参数本身(Parameters),参数的梯度(Gradients),以及优化器的状态(Optimizer States)。优化器状态一般包含多个参数(如momentum和variance)并以高精度(如FP32)存储。所以优化器状态需要的内存空间是远大于参数和梯度的,对应图中绿色块大于蓝色和橙色块。而ZeRO做的就是将这套完整的模型一点一点切分到所有并行的节点上去。
为方便对比,先把基线DP的训练流程列出来:
- 各节点收到完整的模型Parameters以及各自批次的训练数据,完成前向和反向传播,得到各自的Gradients;
- 通过AllReduce整合各节点的Gradients得到聚合好的完整的Gradients;
- 各节点使用完整的Gradients和完整的Optimizer States更新完整的模型parameters。
ZeRO-1(图中的第二行)首先对存储需求量最大的优化器状态开刀。通过切分优化器状态到各节点,ZeRO-1将内存需求降低了4倍,单卡通信数据量提高1.5倍(对比基线)。这背后的训练流程如下:
- 各节点收到完整的模型Parameters以及各自批次的训练数据,完成前向和反向传播,得到各自的Gradients;(同基线)
- 通过AllReduce整合各节点的Gradients得到聚合好的完整的Gradients;(同基线)
- 各节点使用完整的Gradients和部分的OptimizerStates更新部分的模型Parameters;
- 通过AllGather收集各节点的部分更新了的Parameters得到完整更新的模型Parameters。
ZeRO-2(图中的第三行)在ZeRO-1的基础上,接着对梯度开刀。通过切分梯度、优化器状态到各节点,ZeRO-2将内存需求降低了8倍,单卡通信数据量基本不变(对比基线)。这背后的训练流程如下:
- 各节点收到完整的模型Parameters以及各自批次的训练数据,完成前向和反向传播,得到各自的Gradients;(同基线)
- 通过ReduceScatter整合各节点的Gradients得到聚合好的部分的Gradients,随传随把不属于各自节点维护的Gradients丢弃;
- 各节点使用部分的Gradients和部分的OptimizerStates更新部分模型Parameters;
- 通过AllGather收集各节点的部分更新了的Parameters得到完整更新的模型Parameters。
ZeRO-3 / FSDP(图中的第四行)在ZeRO-2的基础上,继续对模型参数本身开刀。通过切分模型参数、梯度、和优化器状态到各节点,ZeRO-3 / FSDP将内存需求降低了64倍,单卡通信数据量提高1.5倍(对比基线)。这背后的训练流程如下:
- 各节点收到部分的模型Parameters以及各自批次的训练数据;
- 做前向传播,各节点通过AllGather收集完整的模型Parameters,随算随把不属于自己维护的模型Parameters丢弃;
- 做反向传播,各节点通过AllGather收集完整的模型Parameters,随算随把不属于自己维护的模型Parameters丢弃;
- 反向传播完,各节点得到各自的Gradients;
- 通过ReduceScatter整合各节点的Gradients得到聚合好的部分的Gradients,随传随把不属于各自节点维护的Gradients丢弃;
- 各节点使用部分的Gradients和部分的OptimizerStates更新部分模型Parameters(无需再对其做AllReduce操作)。
总结一下,ZeRO和FSDP是DP的进阶版,它们在DP的基础上加入了模型并行的思路。通过切分模型的优化器状态、梯度、模型参数本身,大幅降低数据并行时模型的内存开销。ZeRO以通信代存储,尽可能的让不同节点不存重复的数据,不做重复的计算,需要时通过通信从其他节点获得。ZeRO在LLM大模型训练中被广泛使用,其涉及到的主要通信操作为ReduceScatter和AllGather,所有ZeRO和FSDP并行的通信模式在HCCL API中均已支持。
总结
本系列推文详细介绍了大模型分布式训练的各种并行策略,聚焦分析了这些并行策略中使用到的通信操作,帮助读者理解分布式训练的原理与通信模式,并且HCCL API对这些通信操作进行了全方位的支持,如需了解详细信息,请查阅HCCL API列表。
参考材料
https://www.zhihu.com/question/637961859
https://cloud.tencent.com/developer/article/2424244
https://zhuanlan.zhihu.com/p/5502876106
https://zhuanlan.zhihu.com/p/618865052
https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
作者其他文章
评论(0)