Use DTensor.from_local instead of distribute_tensor
kwen2501 committed Sep 11, 2024
1 parent b4b356c commit 55f36f0
Showing 1 changed file with 32 additions and 17 deletions.
tutorials/developer_api_guide/tensor_parallel.py

@@ -41,32 +41,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 # Test #
 ########
 if __name__ == "__main__":
+    # To make sure different ranks create the same module
+    torch.manual_seed(5)
+
     m = M()
-    example_inputs = (100 * torch.randn(128, 1024),)
-    m(*example_inputs)
+    example_input = 100 * torch.randn(128, 1024)
+    m(example_input)
 
 
     import os
     import torch
-    from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
+    from torch.distributed._tensor import DTensor, Replicate, Shard
     import torch.distributed as dist
 
-    # initialize a fake process group
-    store = torch.testing._internal.distributed.fake_pg.FakeStore()
-    dist.init_process_group(
-        backend="fake",
-        world_size=2,
-        rank=0,
-        store=store,
-    )
-    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
+    # initialize a real process group; WORLD_SIZE and RANK come from the launcher
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["RANK"])
+    dist.init_process_group(backend="nccl")
+    mesh = dist.init_device_mesh("cuda", (world_size,))

     # Shard this tensor over the mesh by sharding `big_tensor`'s 0th dimension over the 0th dimension of `mesh`.
-    quantized_weight = to_my_dtype_tp(m.linear.weight)
+    orig_weight = m.linear.weight
+    quantized_weight = to_my_dtype_tp(orig_weight)
     print("quantized weight:", quantized_weight)
-    quantized_weight_dtensor = distribute_tensor(quantized_weight, mesh, [Shard(dim=0)])
-    print("quantized weight dtensor:", quantized_weight_dtensor)
+    # Number of rows per rank
+    n_local_rows = orig_weight.size(0) // world_size
+    # TODO: add support for aten.slice.Tensor
+    quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+    print("quantized shard:", quantized_shard)
+    # Construct DTensor from local shard
+    quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
+    print("quantized dtensor:", quantized_dtensor)

     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
-        quantized_weight_dtensor, requires_grad=False
+        quantized_dtensor, requires_grad=False
     )
 
-    m(*example_inputs)
+    # We need to turn inputs into DTensor form as well -- just a format change
+    input_dtensor = DTensor.from_local(
+        example_input, mesh, [Replicate()]
+    )
+    print("input dtensor:", input_dtensor)
+
+    m(input_dtensor)
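
For context on what this swap buys: distribute_tensor starts from the full global tensor and carves out each rank's shard by running ops on the input under the hood, while DTensor.from_local takes a shard each rank already holds and wraps it with the sharding metadata, no resharding involved. Below is a minimal single-process sketch of the two construction paths, assuming a gloo process group of world size 1; the tensor names are illustrative, not part of the tutorial.

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor, init_device_mesh

# Assumed setup: a single-process gloo group so the sketch runs on one CPU.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", world_size=1, rank=0)
mesh = init_device_mesh("cpu", (1,))

full_weight = torch.randn(8, 4)

# Old path: hand distribute_tensor the global tensor; it produces each
# rank's Shard(0) piece by operating on the input itself.
dt_distributed = distribute_tensor(full_weight, mesh, [Shard(0)])

# New path: slice out the local shard yourself, then just wrap it.
local_shard = full_weight[0:8, :]  # with world_size=1 the shard is everything
dt_from_local = DTensor.from_local(local_shard, mesh, [Shard(0)])

# Both report the same global shape.
print(dt_distributed.shape, dt_from_local.shape)  # torch.Size([8, 4]) twice
dist.destroy_process_group()

Note that the rewritten test also trades the fake single-rank process group for a real nccl one driven by WORLD_SIZE and RANK, so it is presumably meant to be launched with something like torchrun --nproc_per_node=2 tensor_parallel.py, which sets those environment variables.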

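And a quick single-process check of the row-sharding arithmetic used above, with made-up sizes (8 rows, world size 2): each rank's block of n_local_rows rows must tile the original tensor exactly, which is the layout the Shard(0) placement describes.

import torch

world_size = 2
weight = torch.randn(8, 4)
n_local_rows = weight.size(0) // world_size  # 4 rows per rank

# Each hypothetical rank takes its contiguous block of rows.
shards = [
    weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
    for rank in range(world_size)
]

# The blocks reassemble into the full tensor.
assert torch.equal(torch.cat(shards, dim=0), weight)
print([tuple(s.shape) for s in shards])  # [(4, 4), (4, 4)]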