diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py
index 87c94b896a..69539d1e21 100644
--- a/tutorials/developer_api_guide/tensor_parallel.py
+++ b/tutorials/developer_api_guide/tensor_parallel.py
@@ -41,32 +41,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # Test #
 ########
 if __name__ == "__main__":
+    # To make sure different ranks create the same module
+    torch.manual_seed(5)
+
     m = M()
-    example_inputs = (100 * torch.randn(128, 1024),)
-    m(*example_inputs)
+    example_input = 100 * torch.randn(128, 1024)
+    m(example_input)
 
     import os
-    import torch
-    from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
+    from torch.distributed._tensor import DTensor, Replicate, Shard
     import torch.distributed as dist
-    # initialize a fake process group
-    store = torch.testing._internal.distributed.fake_pg.FakeStore()
-    dist.init_process_group(
-        backend="fake",
-        world_size=2,
-        rank=0,
-        store=store,
-    )
-    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
+
+    # initialize a real process group; rank and world size come from the launcher (e.g. torchrun)
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["RANK"])
+    dist.init_process_group(backend="nccl")
+    mesh = dist.init_device_mesh("cuda", (world_size,))
+
     # Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
-    quantized_weight = to_my_dtype_tp(m.linear.weight)
+    orig_weight = m.linear.weight
+    quantized_weight = to_my_dtype_tp(orig_weight)
     print("quantized weight:", quantized_weight)
 
-    quantized_weight_dtensor = distribute_tensor(quantized_weight, mesh, [Shard(dim=0)])
-    print("quantized weight dtensor:", quantized_weight_dtensor)
+    # Number of rows per rank
+    n_local_rows = orig_weight.size(0) // world_size
+    # TODO: add support for aten.slice.Tensor
+    quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+    print("quantized shard:", quantized_shard)
+    # Construct DTensor from local shard
+    quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
+    print("quantized dtensor:", quantized_dtensor)
+
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
-        quantized_weight_dtensor, requires_grad=False
+        quantized_dtensor, requires_grad=False
     )
 
-    m(*example_inputs)
+    # We need to turn inputs into DTensor form as well -- just a format change
+    input_dtensor = DTensor.from_local(
+        example_input, mesh, [Replicate()]
+    )
+    print("input dtensor:", input_dtensor)
+
+    m(input_dtensor)
+
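Note: because the updated script reads WORLD_SIZE and RANK from the environment and initializes an NCCL process group, it has to be started through a distributed launcher rather than run directly. A minimal launch sketch, assuming a single node with 2 GPUs:

    torchrun --nproc_per_node=2 tutorials/developer_api_guide/tensor_parallel.py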