## Why DModule?

- `nn.Module` lacks the semantics of being distributed across multiple devices and running distributed operators.
- Manually managing `DTensor` and `Tensor` within a `nn.Module` in distributed settings is painful and error-prone.
## What is DModule?

- `DModule` (Distributed Module) provides a single-device abstraction for a multiple-device `nn.Module` and empowers users to write distributed training/inference code as if on a single device (i.e., SPMD).
- `DModule` unifies Module-level Tensor Parallelism and Sequence Parallelism by transparently handling distributed logic under the hood:
  - converting `Tensor` to `DTensor` within a `nn.Module`

  - handling `DTensor` sharding and resharding during forward and backward
- A `nn.Module` of `DTensor` (i.e., a `DModule`) is created via the Module-level API `parallelize_module()` with a given `sharding_plan` (see the MLP example below).
- `sharding_plan` can be either written manually or generated automatically (the MLP example below uses a manual plan).
- `DModule` supports `deferred_init()` (i.e., initialize with a Fake Tensor without allocating memory, then shard the Fake Tensor with TP, and then materialize only a shard of the Tensor on device); a sketch is included in the usage section below.

- `DModule` supports third-party extensions (e.g., `APEX`).
## How does veScale DModule differ from PyTorch `parallelize_module`?

veScale `DModule` is inspired by PyTorch's `parallelize_module`, but is developed with an explicit Module-level abstraction and complete features for our production usage.
veScale `DModule` extends PyTorch `parallelize_module` with extra features (see the APIs and examples referenced below).
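For orientation only: upstream PyTorch's `parallelize_module` (in `torch.distributed.tensor.parallel`, PyTorch 2.x) is driven by per-submodule `ParallelStyle` objects, whereas veScale's `DModule` is configured with a module-level `sharding_plan` (see the MLP example in the next section). The snippet below shows the upstream PyTorch API, not veScale's.

``` python
# Upstream PyTorch Tensor Parallel API (PyTorch 2.x), shown for comparison only.
# Run under torchrun so the default process group is initialized.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (4,))  # 1-D mesh over 4 GPUs
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# The plan maps submodule names to per-submodule parallel styles.
model = parallelize_module(
    model,
    tp_mesh,
    {"0": ColwiseParallel(), "2": RowwiseParallel()},
)
```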
## How to use DModule?

- Example of `MLP`:
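A minimal sketch of the flow: the `parallelize_module` import follows the API file referenced below, while the `DeviceMesh`/placement import paths and the exact `sharding_plan` schema (a `"parameter"` section and a `"forward"` section keyed by qualified names) are assumptions for illustration; consult `<repo>/vescale/dmodule/api.py` and `<repo>/test/dmodule/*.py` for the authoritative forms.

``` python
# Illustrative sketch only; run under torchrun so the default process group
# is initialized and each rank owns one GPU.
import torch
import torch.nn as nn

from vescale.dmodule.api import parallelize_module            # see API file below
from vescale.dtensor.device_mesh import DeviceMesh            # assumed import path
from vescale.dtensor.placement_types import Replicate, Shard  # assumed import path


class MLP(nn.Module):
    def __init__(self, hidden: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(hidden, 4 * hidden)
        self.fc2 = nn.Linear(4 * hidden, hidden)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# 1-D device mesh over 4 GPUs used as the Tensor Parallel group.
device_mesh = DeviceMesh("cuda", list(range(4)))

# Manual sharding_plan (assumed layout): column-shard fc1, row-shard fc2,
# and reshard the final output back to Replicate.
sharding_plan = {
    "parameter": {
        "fc1.weight": [Shard(0)],
        "fc1.bias": [Shard(0)],
        "fc2.weight": [Shard(1)],
        "fc2.bias": [Replicate()],
    },
    "forward": {
        "fc1.input": [[Replicate()]],
        "fc2.output": [[Replicate()]],
    },
}

# Convert the single-device nn.Module into a DModule: Tensors become DTensors,
# and sharding/resharding is handled transparently in forward and backward.
mlp = parallelize_module(MLP(), device_mesh, sharding_plan)

# Train as if on a single device (SPMD).
optimizer = torch.optim.SGD(mlp.parameters(), lr=1e-2)
x = torch.randn(8, 1024, device="cuda")
loss = mlp(x).sum()
loss.backward()
optimizer.step()
```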
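The deferred-initialization path described above could then look roughly like this, reusing the `MLP`, `device_mesh`, and `sharding_plan` from the sketch above. The `deferred_init` import path and the assumption that `parallelize_module()` materializes the fake parameters directly as local shards are illustrative, not the authoritative API.

``` python
# Illustrative sketch (import path is an assumption): build the model with
# Fake Tensors so no real memory is allocated, then shard with TP and
# materialize only the local shard of each parameter on this rank's device.
from vescale.initialize.deferred_init import deferred_init  # assumed path

fake_mlp = deferred_init(MLP, hidden=1024)  # parameters are Fake Tensors (no memory)
mlp = parallelize_module(fake_mlp, device_mesh, sharding_plan)  # materialize shards only
```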
- APIs can be found in `<repo>/vescale/dmodule/api.py`

- More examples can be found under `<repo>/test/dmodule/*.py`