torch.Tensor lacks the semantics of being distributed across multiple devices and of running distributed operators
Manually managing torch.Tensor in distributed settings is painful and error-prone: it requires carefully handling the sharded storage on each device, the collective communication among devices, and the operator kernels split across devices.
DTensor (Distributed Tensor) provides a single-device abstraction for a multiple-device torch.Tensor and empowers users to write distributed training/inference code as if on a single device (i.e., SPMD)
DTensor transparently handles all distributed logic under the hood (sharded storage on each device, the collective communication among devices, and the operator kernel split across devices)
DTensor is implemented as a wrapper class on torch.Tensor with metadata (DTensorSpec) describing:
which devices (DeviceMesh) the DTensor is distributed upon, e.g.:
DeviceMesh("cuda", [0, 1]) (a 1-D mesh of 2 devices)
DeviceMesh("cuda", [[0, 1], [2, 3]]) (a 2-D mesh of 4 devices)
how the DTensor is placed (Placement) on the DeviceMesh:
there are three main Placement types:
Shard(<tensor_dim>): DTensor's <tensor_dim> is sharded on the DeviceMesh
Replicate: DTensor is replicated on the DeviceMesh
Partial: DTensor is a partial product on the DeviceMesh, with a pending sum (AllReduce) required to become the total product
A list of Placement is then needed to define the placements of a DTensor:
placements = [Shard(1)] means DTensor's tensor dim #1 is sharded along DeviceMesh's dim #0 (i.e., the #0 element in the list)
placements = [Shard(1), Shard(0)] means DTensor's tensor dim #1 is sharded along DeviceMesh's dim #0 and DTensor's tensor dim #0 is sharded along DeviceMesh's dim #1
placements = [Shard(1), Replicate()] means DTensor's tensor dim #1 is sharded along DeviceMesh's dim #0 and the DTensor is replicated along DeviceMesh's dim #1
what the global tensor shape & stride (TensorMeta) of this DTensor are
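For illustration, below is a minimal sketch of how these pieces fit together; the import paths are assumptions modeled on PyTorch DTensor and the mention of <repo>/vescale/dtensor/api.py, not a verified API surface.

```python
import torch
from vescale.dtensor.device_mesh import DeviceMesh            # assumed module path
from vescale.dtensor.placement_types import Shard, Replicate  # assumed module path
from vescale.dtensor.api import distribute_tensor             # assumed module path

# run with, e.g., torchrun --nproc_per_node=4 this_script.py
# a 2x2 DeviceMesh over 4 GPUs: mesh dim #0 = rows, mesh dim #1 = columns
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])

# the global tensor, written as if on a single device
t = torch.randn(8, 16)

# tensor dim #1 is sharded along mesh dim #0; the DTensor is replicated along mesh dim #1
placements = [Shard(1), Replicate()]

dt = distribute_tensor(t, mesh, placements)
# each rank holds an 8 x 8 local shard, while the DTensorSpec keeps the
# DeviceMesh, the placements, and the global shape/stride (TensorMeta)
print(dt.to_local().shape)  # torch.Size([8, 8]) on every rank
```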
DTensor operators (e.g., torch.add) are implemented by ShardingPropagator which propagates placements from input to output for each operator with pre-registered sharding rules and strategies
veScale is a PyTorch-native framework rooted in PyTorch DTensor
veScale DTensor extends and enhances the PyTorch DTensor for our production standard with extra features as below:
enabled "correct random ops" under abitrary sharding and uneven sharding, i.e., always guarantee random op sharded on multi device is equal to random op on a single device.
enabled DTensor support for third-party plug-in ops (e.g., APEX) by unleashing DTensor.data_ptr and handling asynchronous collective tensors (e.g., in from_local, to_local, redistribute)
made the implicit _Partial placement an explicit Partial placement for optimized initialization, output, and checkpointing (with an extra dispatch mode)
enabled DTensor ops that were not implemented in PyTorch for forward and/or backward:
argmax, argmin, topk, _unique2, scatter_, scatter, select, alias, index_put_, index_put, index_add_, _scaled_dot_product_flash_attention, _scaled_dot_product_efficient_attention, expand_as, one_hot, where, and Embedding in vocabulary parallelism
supported uneven sharding in the conversion between DTensor and torch.Tensor
decoupled special op handling that bypasses DTensor dispatching (_bypass_for_dispatch)
enabled patching before (_pre_patch_for_dispatch) and after (_post_patch_for_dispatch) DTensor dispatch, allowing users to add custom dispatching logic without touching the original dispatch logic
enabled short-cut for ops to bypass sharding propagation entirely (_bypass_for_sharding_prop):
bypassed tensor_meta propagation for ops:
Replicate, by using the local output Tensor's tensor_meta
tensor_meta propagation under dtensor/ops (e.g., conv, slice, copy, clone, bucketize, t)
recompute_tensor_meta_list (e.g., clone, native_dropout, nll_loss_forward)
enabled DeviceMesh on meta device type
enabled DeviceMesh initialization from an existing process group
enabled DeviceMesh being split into a list of sub meshes
disabled redistribution of input DTensors during dispatch (VESCALE_DISABLE_REDISTRIBUTE), as we don't expect uncontrollable resharding and implicit communication in DTensor dispatch for production (ideally, all resharding and communication should be controlled by the end users)
supported deferred initialization and materialization for DTensor with an extended torchdistx
[experimental] developed InterleavedShard placement to support merged QKV in MHA
[experimental] extreme performance with C++ DTensor
[experimental] extreme performance with dispatching-free DTensor
Example of matmul:
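A minimal sketch of a distributed matmul written as single-device code, under the same assumed import paths as above (run with, e.g., torchrun --nproc_per_node=4):

```python
import torch
from vescale.dtensor.device_mesh import DeviceMesh            # assumed module path
from vescale.dtensor.placement_types import Shard, Replicate  # assumed module path
from vescale.dtensor.api import distribute_tensor             # assumed module path

mesh = DeviceMesh("cuda", [0, 1, 2, 3])  # 1-D mesh over 4 GPUs

a = torch.randn(16, 8)   # global shapes, exactly as in single-device code
b = torch.randn(8, 4)

da = distribute_tensor(a, mesh, [Shard(0)])      # row-shard `a` along mesh dim #0
db = distribute_tensor(b, mesh, [Replicate()])   # replicate `b` on every rank

# a plain torch op on DTensors: sharding propagation infers the output is
# Shard(0), so each rank computes only its 4 x 4 tile of the 16 x 4 result
dc = torch.mm(da, db)
print(dc.to_local().shape)  # torch.Size([4, 4]) on every rank
```

With Shard(0) x Replicate, no communication is needed; had `a` been sharded on tensor dim #1 and `b` on tensor dim #0, the output would instead carry a Partial placement with a pending sum (AllReduce).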
APIs can be found under <repo>/vescale/dtensor/api.py
More examples can be found under <repo>/test/dtensor/*/*.py
Original examples can be found in PyTorch DTensor.
-- Register DTensor "Ops" for Sharding Propagation!
Sharding propagation is an important step in DTensor dispatch. It is responsible for inferring the output sharding info (i.e., DTensorSpec) from the input sharding info at each operator, so that all ops of an entire model can be expressed in DTensor.
There are two ways to register sharding propagation, namely the rule-based way and the strategy-based way.
They are intrinsically the same thing. The difference is that the rule-based way only needs to consider the current input DTensorSpec, while the strategy-based way requires enumerating all valid (input DTensorSpec, output DTensorSpec) pairs for a single op.
The pro of the rule-based way is its ease of use, while the pro of the strategy-based way is that it captures all possible combinations of input-output sharding, which is exactly the context needed for automatically selecting the best input-output sharding (e.g., the one with the minimal DTensor redistribution cost).
It is recommended to use the strategy-based way to register sharding propagation, but if you encounter a really complex custom op, the rule-based way might be the better choice. A sketch of the rule-based flavor follows.
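As a rough illustration of the rule-based flavor, here is a sketch of registering an element-wise op. The decorator name, module paths, and class names are assumptions modeled on PyTorch DTensor's internals (in veScale the real helpers live somewhere under <repo>/vescale/dtensor/ops/), so treat this as a sketch rather than the exact API:

```python
import torch
from vescale.dtensor.ops.utils import register_prop_rule        # assumed path/name
from vescale.dtensor.op_schema import OpSchema, OutputSharding  # assumed path/name

aten = torch.ops.aten

# Rule-based registration: given only the *current* input DTensorSpec,
# produce the output DTensorSpec. For a strictly element-wise unary op,
# the output simply inherits the input's mesh, placements, and tensor_meta.
@register_prop_rule(aten.relu.default)
def relu_rule(op_schema: OpSchema) -> OutputSharding:
    (input_spec,) = op_schema.args_spec
    return OutputSharding(output_spec=input_spec)
```

The strategy-based flavor would instead register a function that enumerates every valid (input placements, output placements) combination for the op, so that the dispatcher can choose among them (e.g., the combination with the minimal redistribution cost).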
Ideally, DTensor should provide the single-device abstraction even for random ops (e.g., dtensor.randn, nn.Dropout, and any other random op), i.e., the random values generated on a single device should be identical to the collection of the random shards generated on multiple devices.
PyTorch DTensor (i.e., its OffsetBasedRNGTracker) does not produce random values on multiple devices that are identical to single-GPU execution for random operators (e.g., dtensor.randn, nn.Dropout, and any other random op).
The key problem is that CUDA random numbers are not generated "sequentially" and thus cannot simply be offset by rank ids; instead, they are generated "simultaneously" by multiple CUDA threads and can only be sharded by CUDA thread ids.
In veScale, we introduce a ThreadBasedRNGTracker that corrects the RNG states across different GPUs, enabling the generation of DTensors identical to the ones from a single GPU for any random op.
To use the feature, build and install a patched PyTorch of veScale and set the environment variable VESCALE_SINGLE_DEVICE_RAND=1.
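A minimal sketch of how this is meant to be used, assuming the same assumed import paths as above and that the flag is read before random DTensor ops run:

```python
import os
# requires veScale's patched PyTorch build; set before any random DTensor op
os.environ["VESCALE_SINGLE_DEVICE_RAND"] = "1"

import torch
from vescale.dtensor.device_mesh import DeviceMesh   # assumed module path
from vescale.dtensor.placement_types import Shard    # assumed module path
from vescale.dtensor.api import distribute_tensor    # assumed module path

mesh = DeviceMesh("cuda", [0, 1])
torch.manual_seed(0)

# a dropout mask drawn on a sharded activation: with the flag on, gathering the
# local shards reproduces exactly the mask a single GPU would have drawn
x = distribute_tensor(torch.ones(8, 8), mesh, [Shard(0)])
y = torch.nn.functional.dropout(x, p=0.5, training=True)
```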
Whenever a randomized operation is invoked on a DTensor, ThreadBasedRNGTracker passes the sharding info to the C++/CUDA side of PyTorch through the RNG state.
This resolves the issue that PyTorch DTensor's OffsetBasedRNGTracker does not produce the output identical to single GPU executions.
For example, consider generating x = torch.rand(4) given the current random seed and a global offset. In CUDA's RNG implementation, random numbers are accessed via a triple (seed, thread id, offset). On a single GPU, 4 GPU threads are created and the i-th thread fills the entry x[i] with rand(seed, i, offset). That is, we have x[i] = rand(seed, i, offset) for i = 0, 1, 2, 3.
After the execution of torch.rand(4), the global offset increments by 4, which is the granularity of CUDA's RNG offsets.
More precisely, the global offset increments by the amount of randomness consumed by each thread, rounded up to the nearest multiple of 4. For instance, if 1000 GPU threads are used to generate 7000 random numbers, each thread draws 7 random numbers from the CUDA RNG and the global offset increases by 8 afterward.
However, using OffsetBasedRNGTracker, the same torch.rand(4) sharded over 2 GPUs outputs a different tensor.
Furthermore, after the execution, the global offset increments by 8 instead of 4.
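Concretely, the failure mode looks like this (an illustration, assuming each GPU draws its 2-element shard with its own local thread ids and its own per-rank offset):

GPU 0: x[0] = rand(seed, 0, offset),     x[1] = rand(seed, 1, offset)
GPU 1: x[2] = rand(seed, 0, offset + 4), x[3] = rand(seed, 1, offset + 4)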
To resolve the issue, each physical thread of each GPU should fill its entry using the thread id it would have as if there were only one GPU. In the previous example, the output should be x[0] = rand(seed, 0, offset) and x[1] = rand(seed, 1, offset) on GPU 0, and x[2] = rand(seed, 2, offset) and x[3] = rand(seed, 3, offset) on GPU 1.
And after the execution, the global offset should increment by 4. This can be done if we pass the sharding info into the CUDA functions that generate these outputs.
We would like to acknowledge the assistance of and collaboration with the PyTorch DTensor team.