In veScale, we provide two optimizers for Optimizer Parallel:

- `DistributedOptimizer`
- `BasicOptimizer`

## DistributedOptimizer

`DistributedOptimizer` is a ZeRO 2+ optimizer. Similar to the original ZeRO2, it parallelizes model gradients and optimizer states along the Data Parallel dimension. Unlike ZeRO2, it further parallelizes model parameters virtually, but not physically.
`DistributedOptimizer` is primarily inherited from Megatron-LM's `DistributedOptimizer`, both for its performance and because native PyTorch lacks a ZeRO2 optimizer. We extend and enhance `DistributedOptimizer` with extra features:

- convert between `Tensor` and `DTensor`
- support online resharding of optimizer state
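
As a rough orientation, the sketch below shows where `DistributedOptimizer` sits relative to `DDP`. It is only an illustration: the `DDP` import path, the keyword arguments (`use_distributed_optimizer`, `models`), and the plain `nn.Linear` stand-in are assumptions, not the authoritative API; consult `<repo>/vescale/optim/distributed_optimizer.py` and the example tests referenced below for the exact signatures.

```python
# Construction sketch only -- argument names and the DDP import path are
# assumptions; see <repo>/vescale/optim/distributed_optimizer.py for the API.
import torch
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer

# In real use, `model` would be a DModule parallelized on a device mesh with a
# Data Parallel dimension; a plain nn.Linear stands in here for brevity.
model = torch.nn.Linear(1024, 1024)

ddp_model = DDP(
    model,
    use_distributed_optimizer=True,  # assumed flag: lay out the Gradient Buffer for sharded updates
)

doptim = DistributedOptimizer(
    torch.optim.Adam(ddp_model.parameters(), lr=1e-3),  # the wrapped "actual" optimizer
    models=[ddp_model],  # assumed keyword: gives the optimizer access to DDP's Gradient Buffer
)
```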
In `DistributedOptimizer`, the model gradients and optimizer states are sharded along the Data Parallel dimension within each gradient Bucket of the Gradient Buffer (see `DDP` for more details). Each DP rank only manages its own shard of the gradient, generates its own shard of optimizer states, and updates its own shard of parameters.
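
To make the per-rank ownership concrete, the toy snippet below splits a flat gradient buffer into contiguous shards, one per DP rank. The buffer size, the DP world size, and the even split are illustrative assumptions, not veScale's actual bucketing logic.

```python
import torch

# Toy illustration (not veScale code): a flat gradient buffer is split into
# equal contiguous shards, one per Data Parallel rank. Each rank keeps
# optimizer state (e.g., Adam's exp_avg) only for its own shard.
dp_world_size = 4
grad_buffer = torch.arange(12, dtype=torch.float32)  # pretend this is one gradient Bucket

shards = torch.chunk(grad_buffer, dp_world_size)      # shard i belongs to DP rank i
for rank, shard in enumerate(shards):
    exp_avg = torch.zeros_like(shard)                 # state allocated only for the local shard
    print(f"rank {rank} owns {shard.tolist()} and state of size {exp_avg.numel()}")
```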
The flow of `DistributedOptimizer` is as follows:

1. Shard the original parameters along the Data Parallel dimension in the `__init__` function. The optimizer's `param_groups` is then replaced with the Sharded Parameter.
2. Receive the Reduced Gradient resulting from `ReduceScatter` per Gradient Bucket in `DDP`.
3. Attach the Reduced Gradient (the `main_grad` of each original parameter) to the Sharded Parameter.
4. Run the actual `optimizer.step()` to generate the Optimizer State of each shard and update the Sharded Parameter with the Reduced Gradient.
5. Copy the updated Sharded Parameter into a specific parameter buffer and get ready for the `AllGather` communication that restores the full parameters.
6. To save memory for `AllGather`, the Gradient Buffer of `DDP` is reused as the communication buffer for `AllGather`.
7. Overlap the `AllGather` with the forward computation of the next iteration to hide communication overhead, similar to the gradient `ReduceScatter` overlap with backward computation.

APIs can be found in: `<repo>/vescale/optim/distributed_optimizer.py`.
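
The flow above is essentially the ZeRO-2 update pattern expressed over `DDP`'s Gradient Buffer. The self-contained sketch below reproduces the same idea with plain `torch.distributed` collectives: reduce gradients, step on the local shard only, then all-gather the updated shards. It is an illustration, not veScale's implementation; in particular it uses a full `all_reduce` where `DDP` uses `ReduceScatter` per Gradient Bucket.

```python
# Plain-PyTorch illustration of the sharded-update pattern described above.
# This is NOT veScale's implementation -- just the same three-phase idea.
import torch
import torch.distributed as dist


def zero2_style_step(flat_param, flat_grad, lr=0.1):
    rank, world = dist.get_rank(), dist.get_world_size()
    shard_size = flat_param.numel() // world  # assume even divisibility for simplicity

    # 1) Reduce gradients across DP ranks, then keep only this rank's shard.
    #    (DDP uses ReduceScatter per Gradient Bucket instead, so the full
    #    reduced gradient is never materialized on a single rank.)
    dist.all_reduce(flat_grad, op=dist.ReduceOp.SUM)
    grad_shard = flat_grad[rank * shard_size:(rank + 1) * shard_size] / world

    # 2) Update only the local parameter shard (plain SGD stands in for the
    #    wrapped optimizer.step(); optimizer state would exist only for this shard).
    param_shard = flat_param[rank * shard_size:(rank + 1) * shard_size]
    param_shard -= lr * grad_shard

    # 3) AllGather the updated shards to restore the full parameters everywhere.
    gathered = [torch.empty_like(param_shard) for _ in range(world)]
    dist.all_gather(gathered, param_shard.contiguous())
    flat_param.copy_(torch.cat(gathered))


if __name__ == "__main__":
    # Launch with e.g. `torchrun --nproc_per_node=2 zero2_sketch.py`.
    dist.init_process_group(backend="gloo")
    flat_param = torch.ones(8)
    torch.manual_seed(0)            # same "gradient" on every rank for a deterministic demo
    flat_grad = torch.randn(8)
    zero2_style_step(flat_param, flat_grad)
    print(f"rank {dist.get_rank()}: params after step = {flat_param.tolist()}")
    dist.destroy_process_group()
```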
More examples can be found in: `<repo>/test/parallel/ddp_optim/test_doptimizer.py`.
## BasicOptimizer

`BasicOptimizer` is not a ZeRO optimizer but a simple optimizer that works like Data Parallel, replicating parameters, gradients, and optimizer states along the Data Parallel dimension.

`BasicOptimizer` itself is nothing but a simple wrapper that wraps a given optimizer instance with utilities for veScale `DTensor`, `DModule`, and `DDP`:
- convert between `Tensor` and `DTensor`
- recover flattened gradient from `DDP`
- trigger gradient synchronization of `DModule` (e.g., for Sequence Parallel)
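
A rough usage sketch follows; the keyword argument `models` and the plain `nn.Linear` stand-in are assumptions, so check the API file below for the authoritative signature.

```python
# Usage sketch only -- the keyword arguments are assumptions, not the
# authoritative API; see <repo>/vescale/optim/base_optimizer.py.
import torch
from vescale.optim.base_optimizer import BasicOptimizer

# In real use, `model` would be a veScale DModule (optionally wrapped in
# veScale DDP); a plain nn.Linear stands in here for brevity.
model = torch.nn.Linear(1024, 1024)

optimizer = BasicOptimizer(
    torch.optim.Adam(model.parameters(), lr=1e-3),  # the wrapped optimizer instance
    models=model,  # assumed keyword: the module(s) whose gradients it synchronizes
)

# The training loop then follows the usual PyTorch pattern
# (backward, then optimizer step, then zeroing gradients).
```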
APIs can be found in: `<repo>/vescale/optim/base_optimizer.py`.

Examples can be found in: `<repo>/test/parallel/ddp_optim/test_ddp.py`.
## Which Optimizer Works with `DDP`?

The compatibility of the above optimizers with `DDP` is as follows:

|          | `BasicOptimizer` | `DistributedOptimizer` |
|----------|------------------|------------------------|
| `DDP`    | yes              | yes                    |
| No `DDP` | yes              | no                     |