分布式预训练中的数据并行是使用广泛且加速性能良好的并行方法。各个数据切片能完全解耦,只需在每个 mini-batch 结束时进行梯度的 all-reduce。数据并行分为中心化方式(如 pytorch 里的 DataParallel)和无中心化方式(如 pytorch 里的 DistributedDataParallel)。这两种方式最大的区别在于 gradient 和 reduce 计算过程。DataParallel 需在 forward 之后把所有输出 gather 到 0 号卡上,计算完 loss 之后再 scatter 到各个设备上,然后做 backward 独立计算 gradient,最后搜集 gradient 到 0 号卡,forward 和 backward 间需插入一次通信。DistributedDataParallel 则是每张卡独立做 forward 和 backward,然后对各卡的 gradient 做 all-reduce,forward 和 backward 间无需通信。此外,ZeRO 的出发点是优化数据并行中的显存占用,因为在数据并行中,每个 device 上都有完整的权重、梯度和优化器状态信息,较为冗余。
数据并行(Data Parallel)是使用最为广泛的并行方法,加速性能非常好,原因是各个数据切片可以做到完全解耦,只需要在最后每个mini-batch结束的时候做一下梯度的all-reduce既可。数据并行可以分为中心化方式的和无中心化方式的,对应于pytorch里面的DataParallel和DistributedDataParallel这两种方式最大的区别是gradient和reduce计算过程DataParallel是要在forward之后把所有输出gather到0号卡上,计算完loss之后再scatter到各个设备上,然后做backward独立计算gradient,最后搜集gradient到0号卡。因此需要在forward和backward间插入一次通信DistributedDataParallel是每张卡独立的做forward和backward,然后对各卡的gradient做all-reduce。因此forward和backward间无需通信
ZeRO的出发点是希望优化数据并行里显存占用。因为在数据并行里面,每个device上都有完整的权重信息,梯度信息和优化器状态信息,这个其实是比较冗余的。