Skip to content

Commit

Permalink
Fix device id
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Sep 22, 2024
1 parent 2870da5 commit 1074ecf
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions tutorials/developer_api_guide/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _(func, types, args, kwargs):
class M(torch.nn.Module):
def __init__(self, in_features, out_features, **kwargs) -> None:
super().__init__(**kwargs)
self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
self.linear = torch.nn.Linear(in_features, out_features, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
Expand Down Expand Up @@ -144,10 +144,15 @@ def main():
# To make sure different ranks create the same module
torch.manual_seed(5)

# Get rank and device
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

# Original model
proj_up = M(1024, 2048)
proj_dn = M(2048, 1024)
example_input = 100 * torch.randn(128, 1024, device="cuda")
proj_up = M(1024, 2048).to(device)
proj_dn = M(2048, 1024).to(device)
example_input = 100 * torch.randn(128, 1024, device=device)
y = proj_dn(proj_up(example_input))

# Quantize the model
Expand All @@ -157,8 +162,6 @@ def main():
print("Quantization works!")

# Create a device mesh
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
dist.init_process_group(backend="nccl")
mesh = dist.init_device_mesh("cuda", (world_size,))

Expand Down

0 comments on commit 1074ecf

Please sign in to comment.