nn.Module lacks the semantics of being distributed across multiple devices and of running distributed operators.
Manually managing DTensor and Tensor within an nn.Module in distributed settings is painful and error-prone.
DModule (Distributed Module) provides a single-device abstraction for a multiple-device nn.Module and empowers users to write distributed training/inference code as if on a single device (i.e., SPMD).
DModule unifies Module-level Tensor Parallelism and Sequence Parallelism by transparently handling distributed logic under the hood:
converting Tensor to DTensor within the nn.Module, and handling DTensor sharding and resharding during forward and backward.
DModule is enabled via the Module-level API parallelize_module() with a given sharding_plan (see the MLP example below).
The sharding_plan itself can be given in either of two forms.
DModule supports deferred initialization via deferred_init() (i.e., initialize with Fake Tensor without allocating memory, then shard the Fake Tensor with TP, and then materialize only a shard of the Tensor on device), and supports third-party extensions (e.g., APEX).
How does DModule differ from PyTorch's parallelize_module?
veScale DModule is inspired by PyTorch's parallelize_module, but is developed with an explicit Module-level abstraction and complete features for our production usage.
veScale DModule extends PyTorch's parallelize_module with the extra features below:
Example of MLP:
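A minimal sketch of the intended flow, assuming a 4-GPU tensor-parallel DeviceMesh; the import paths, the "parameter"/"forward" plan keys, and the placements are assumptions for illustration rather than the documented API (see <repo>/vescale/dmodule/api.py and <repo>/test/dmodule/*.py for the authoritative signatures and plan format):

``` python
import torch
import torch.nn as nn

# NOTE: import paths and the sharding_plan format below are assumptions for
# illustration; check <repo>/vescale/dmodule/api.py for the real API.
from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dtensor.placement_types import Replicate, Shard
from vescale.dmodule.api import parallelize_module


class MLP(nn.Module):
    """A plain single-device nn.Module: no distributed code inside."""

    def __init__(self, hidden: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(hidden, 4 * hidden)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(4 * hidden, hidden)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))


# Run under torchrun so torch.distributed is initialized on every rank.
# 1-D mesh over 4 GPUs (assumed world size), used for Tensor Parallelism.
device_mesh = DeviceMesh("cuda", list(range(4)))

# Hypothetical sharding_plan: Megatron-style TP for the two Linears
# (column-shard fc1, row-shard fc2), plus how activations enter/leave.
sharding_plan = {
    "parameter": {
        "fc1.weight": [Shard(0)],   # shard output features of fc1
        "fc1.bias": [Shard(0)],
        "fc2.weight": [Shard(1)],   # shard input features of fc2
        "fc2.bias": [Replicate()],
    },
    "forward": {
        "fc1.input": [[Replicate()]],   # replicated activation enters the MLP
        "fc2.output": [[Replicate()]],  # reduce partial sums on the way out
    },
}

# One Module-level call: Tensors become DTensors, and resharding is handled
# transparently during forward and backward.
mlp = parallelize_module(MLP(), device_mesh, sharding_plan)

x = torch.randn(8, 1024, device="cuda", requires_grad=True)
out = mlp(x)          # written as if on a single device (SPMD)
out.sum().backward()  # gradients are DTensors, sharded like the parameters
```

Each rank runs this same script (e.g., under torchrun), so the code reads as single-device while executing SPMD across the mesh.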
APIs can be found in <repo>/vescale/dmodule/api.py
More examples can be found under <repo>/test/dmodule/*.py