In veScale, we provide two optimizers for Optimizer Parallel:

- DistributedOptimizer
- BasicOptimizer
## DistributedOptimizer

DistributedOptimizer is a ZeRO 2+ optimizer. Like the original ZeRO2, it parallelizes the model gradients and optimizer states along the Data Parallel dimension. Unlike ZeRO2, it further parallelizes the model parameters virtually but not physically.
DistributedOptimizer primarily inherits from Megatron-LM's DistributedOptimizer for its performance, and mostly because native PyTorch lacks a ZeRO2 optimizer. We extend and enhance DistributedOptimizer with extra features:

- convert between Tensor and DTensor
- support online resharding of optimizer state
In DistributedOptimizer, the model gradients and optimizer states are sharded along the Data Parallel dimension within each gradient Bucket of the Gradient Buffer (see DDP for more details): each DP rank only manages its own shard of the gradient, generates its own shard of optimizer states, and updates its own shard of parameters.
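
To picture the layout, here is a small self-contained sketch in plain PyTorch. It simulates the DP ranks with a loop instead of real processes; names such as `bucket`, `dp_size`, and `shard_len` are illustrative and not veScale APIs.

```python
import torch

# A flattened gradient bucket, standing in for one Bucket of DDP's Gradient Buffer.
bucket = torch.randn(12)      # reduced gradients of several parameters, flattened
params = torch.randn(12)      # the matching flattened parameters
dp_size, lr = 4, 0.1
shard_len = bucket.numel() // dp_size

new_shards = []
for rank in range(dp_size):
    # Each DP rank owns one contiguous shard of the bucket ...
    s = slice(rank * shard_len, (rank + 1) * shard_len)
    grad_shard = bucket[s]              # its shard of the reduced gradient
    param_shard = params[s].clone()     # its shard of the parameters
    # ... keeps optimizer states only for that shard, and updates only that shard.
    param_shard -= lr * grad_shard      # stand-in for optimizer.step()
    new_shards.append(param_shard)

# AllGather would then restore the full, updated parameters on every rank.
updated_params = torch.cat(new_shards)
```
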
The flow of DistributedOptimizer is as follows:
- Shard the model parameters along the Data Parallel dimension in the `__init__` function; the optimizer's `param_groups` is then replaced with the Sharded Parameter
- Receive the Reduced Gradient resulting from ReduceScatter per Gradient Bucket in DDP
- Attach the Reduced Gradient (the `main_grad` of each original parameter) to the Sharded Parameter
- Run the actual `optimizer.step()` to generate the Optimizer State of each shard and update the Sharded Parameter with the Reduced Gradient
- Copy the updated Sharded Parameter to a specific parameter buffer and get ready for the AllGather communication that restores the full parameters
- To save memory, reuse the Gradient Buffer of DDP as the communication buffer for AllGather
- Overlap the AllGather with the forward computation of the next iteration to hide communication overhead, similar to how the gradient ReduceScatter overlaps with the backward computation

APIs can be found in: `<repo>/vescale/optim/distributed_optimizer.py`.
More examples can be found in: `<repo>/test/parallel/ddp_optim/test_doptimizer.py`.
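
For reference, a minimal end-to-end sketch is shown below. The DDP and DistributedOptimizer keyword arguments (`data_pg_or_device_mesh`, `use_distributed_optimizer`, `models`, `overlap_param_gather`) follow veScale's examples but are assumptions here; check the API and test files above for the exact signatures. `dmodule`, `device_mesh`, and `batch` are placeholders assumed to be built elsewhere.

```python
import torch
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer

# Assumed inputs: `dmodule` is a parallelized DModule, `device_mesh` is a veScale
# DeviceMesh with a Data Parallel ("DP") dimension, and `batch` is an input tensor.
ddp_module = DDP(
    dmodule,
    data_pg_or_device_mesh=device_mesh["DP"],
    overlap_grad_reduce=True,
    use_distributed_optimizer=True,  # lay out the Gradient Buffer for ZeRO-style sharding
)

doptim = DistributedOptimizer(
    torch.optim.Adam(ddp_module.parameters(), lr=1e-3),
    models=[ddp_module],
    overlap_param_gather=False,
)

# One training step: backward fills DDP's Gradient Buffer and triggers ReduceScatter;
# the optimizer then updates only this rank's shard and AllGathers the full parameters.
doptim.zero_grad()
ddp_module(batch).sum().backward()
doptim.step()
```
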
## BasicOptimizer

BasicOptimizer is not a ZeRO optimizer but a simple optimizer that works like Data Parallel, replicating parameters, gradients, and optimizer states along the Data Parallel dimension.
BasicOptimizer itself is nothing but a simple wrapper that wraps a given optimizer instance with utilities for veScale DTensor, DModule, and DDP:

- convert between Tensor and DTensor
- recover the flattened gradient from DDP
- trigger gradient synchronization of DModule (e.g., for Sequence Parallel)
APIs can be found in: `<repo>/vescale/optim/base_optimizer.py`.
Examples can be found in: `<repo>/test/parallel/ddp_optim/test_ddp.py`.
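
A minimal usage sketch follows, assuming the wrapper takes the underlying optimizer plus the module(s) it manages via a `models` argument (an assumption; check the API file above for the exact signature). `dmodule` and `batch` are placeholders built elsewhere.

```python
import torch
from vescale.optim.base_optimizer import BasicOptimizer

# Assumed inputs: `dmodule` is a parallelized DModule (DDP wrapping is optional here)
# and `batch` is an input tensor.
optim = BasicOptimizer(
    torch.optim.Adam(dmodule.parameters(), lr=1e-3),
    models=dmodule,  # may need to be a list, e.g. [dmodule], depending on the version
)

# A regular training step: gradients stay replicated across Data Parallel ranks;
# the wrapper handles DTensor gradients and DModule gradient synchronization.
optim.zero_grad()
dmodule(batch).sum().backward()
optim.step()
```
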
## Compatibility with DDP

The compatibility of the above optimizers with DDP is as follows:

|        | BasicOptimizer | DistributedOptimizer |
|--------|----------------|----------------------|
| DDP    | yes            | yes                  |
| No DDP | yes            | no                   |