分布式预训练中的流水线并行是一种在模型参数量太大一张卡不能完全放下时采用的切分方法。具体而言,沿着模型的拓扑序将其切分成 p 段,每段为一个 stage。将大小为 N 的 mini-batch 进一步切分为 m 个大小为 M 的 micro-batch(N = m·M),这些 micro-batch 依次进入上述 p 个 stage。
在流水线并行切分后,每个 stage 的耗时会发生变化。假设 1 指出:算力为 1 的节点,处理完整模型的 1 个 micro-batch,前向和反向耗时分别是 t_f 和 t_b。基于此有推论 1:算力为 1/p 的节点,处理完整模型的 1 个 micro-batch,前向和反向耗时分别是 p·t_f 和 p·t_b;推论 2:算力为 1 的节点,处理 1/p 模型的 1 个 micro-batch,前向和反向耗时分别是 t_f/p 和 t_b/p;推论 3:算力为 1/p 的节点,处理 1/p 模型的 1 个 micro-batch,前向和反向耗时分别是 t_f 和 t_b。
从耗时情况来看,理论上界是显存无限大,不需要 pipeline 并行,一把梭直接对 mini-batch 的样本做前向和反向,耗时正比于样本数量,b_best 耗时为 t_best = m·(t_f + t_b)。而最朴素的串行方式,每个 micro-batch 串行逐个做前向和反向,一个 micro-batch 的耗时是 p·(t_f + t_b),那么 1 个 mini-batch 的耗时 t_sequential 是 t_sequential = mp·(t_f + t_b),可见耗时是理论上界的 p 倍,存在大量计算资源闲置空载,硬件利率用很低。
最后小结,3D 并行包括数据并行(DP)、流水线并行(PP)和模型并行(TP)。DP 计算和通信效率友好,但权重显存不友好;PP 要求 mini-batch 里 batch size 足够大以掩盖流水线带来的 overhead,batch size 过大则会增大激活显存占用;TP 权重显存友好,但计算和通信效率不友好,通信量要求大。ZeRO 针对数据并行显存占用大的问题提出优化,Alpa 借鉴 AI 编译器思路对 3D 并行建模并用自动化搜索方式得到并行策略。
因此根据推论3可知,所有算力处理整个模型的耗时,跟$$1/p$$的算力处理$$1/p$$段模型的耗时,是一致的。下面开始讨论各种策略下的耗时情况。[heading2]理论上界[content]我们先考虑理想最优情况。此时显存无限大,不需要pipeline并行,一把梭直接对mini-batch的样本做前向和反向,耗时正比于样本数量。由假设1不难算出$$b_\text{best}$$耗时为$$t_{\text{best}}=m\cdot(t_f+t_b)$$[heading2]朴素串行[content]下面考虑最朴素的串行方式,类似于CPU的单周期处理器时代。每个micro-batch串行逐个做前向和反向,如下所示这里蓝色是前向过程,绿色是反向过程,粗略认为反向过程是前向的2倍耗时。此时由推论3知道一个micro-batch的耗时是$$p\cdot(t_f+t_b)$$,那么1个mini-batch的耗时$$t_\text{sequential}$$是$$t_\text{sequential}=mp\cdot(t_f+t_b)$$可见耗时是理论上界的$p$倍,其中有大量的计算资源在闲置空载,硬件利率用很低
当模型参数量太大,一张卡不能完全放下的情况下,就必须对模型进行切分了,流水线并行(Pipeline Parallel)就是一种切分方法。具体来说沿着模型的拓扑序,切分成$$p$$段,每一段为一个stage,因此可以形成逻辑上相互串联的$$p$$个stage将大小为$$N$$的mini-batch进一步切分为$$m$$个大小为$$M$$的micro-batch,因此$$N=m\cdot M$$这些micro-batch依次进入上述$$p$$个stage我们首先研究一下流水线并行切分之后,每个stage的耗时跟原来相比会有什么变化,如下所示假设1:算力为$$1$$的节点,处理完整模型的1个micro-batch,前向和反向耗时分别是$$t_f$$和$$t_b$$推论1:算力为$$1/p$$的节点,处理完整模型的1个micro-batch,前向和反向耗时分别是$$p\cdot t_f$$和$$p\cdot t_b$$推论2:算力为$$1$$的节点,处理$$1/p$$模型的1个micro-batch,前向和反向耗时分别是$$t_f/p$$和$$t_b/p$$推论3:算力为$$1/p$$的节点,处理$$1/p$$模型的1个micro-batch,前向和反向耗时分别是$$t_f$$和$$t_b$$
最后小结一下,本文主要有4部分内容分布式通信原语包括了点对点通信和集合通信的方法。其中集合通信包括了一对多的Broadcast和Scatter,多对一的Reduce和Gather,多对多的AllReduce、AllGather、ReduceScatter和AllToAll3D并行包括了数据并行(DP)、流水线并行(PP)和模型并行(TP)DP的优势是计算和通信效率都很友好,但是权重的显存不友好,每张卡都有一份PP的问题是要求mini-batch里面batch size足够大才能掩盖住流水线带来的overhead。batch size如果过大,会增大激活显存的占用TP的优势是权重显存非常友好,没有冗余。但是计算和通信效率不友好,通信量要求很大,在超出了一个island的时候性能下降很快。ZeRO针对数据并行显存占用大的问题,借鉴了Parameter Server的思路,提出了ZeRO-1,ZeRO-2和ZeRO-3的优化。其中ZeRO-2让每张卡只维护一部分的梯度和优化器状态,显存占用减少到原来的$1/8$,通信带宽保持不变Alpa鉴了AI编译器的思路,对3D并行进行建模,用自动化搜索的方式得到了仅次于手工最优的并行策略PS:由于笔者小A并没有亲手撸过上述内容的所有细节,大部分是通过研究代码和精读优秀文章的方式bottom-up总结而来,本质上是个拾人牙慧的知识搬运工,所以终究是纸上谈兵。因此希望各方有实际经验的大佬猛锤,思维碰撞才生火花,真理越辩越明。如果想了解transformer在NLP/多模态/AIGC的算法知识,分布式训练的知识,以及如何在TVM上做PTQ量化和部署,可以关注我aaronxic哟~