change transpose to non-inplace op
jerryzh168 committed Sep 19, 2024
1 parent 2668c67 commit 243ebc5
Showing 2 changed files with 6 additions and 9 deletions.
tutorials/developer_api_guide/my_dtype_tensor_subclass.py (5 changes: 2 additions & 3 deletions)

@@ -194,7 +194,7 @@ def dequantize(self, output_dtype=None):
         block_size = (1, int_data.shape[-1])
         if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
             transposed = True
-        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype)
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
         if transposed:
             res = res.t()
         return res
@@ -331,8 +331,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             elif dim == 1:
                 return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
         elif func is aten.t.default:
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
 
         raise NotImplementedError(
             f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
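
The first hunk simply threads output_dtype through to dequantize_affine. The second hunk is the heart of this commit: the aten.t.default branch used to flip args[0].transposed in place and hand the same object back, mutating a layout tensor that other wrappers (for example the sharded tensors in the tensor-parallel tutorial below) may still alias. It now builds a fresh PlainMyDTypeLayout with the flag flipped. A minimal sketch of the resulting behavior, assuming PlainMyDTypeLayout from this file is in scope and using hypothetical int_data, scale, and layout_type stand-ins (only the four-argument constructor and the .t() dispatch come from the diff above):

    import torch

    # Hypothetical placeholder data; the tutorial builds these when quantizing a float weight.
    int_data = torch.randint(-8, 8, (1024, 1024), dtype=torch.int8)
    scale = torch.randn(1024)
    layout_type = ...  # stand-in for whatever layout type object the tutorial passes around

    plain = PlainMyDTypeLayout(int_data, scale, False, layout_type)
    plain_t = plain.t()          # routes through __torch_dispatch__ to the aten.t.default branch above
    assert plain_t.transposed    # the new layout tensor carries the flipped flag
    assert not plain.transposed  # the original is left untouched, no in-place mutation
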
tutorials/developer_api_guide/tensor_parallel.py (10 changes: 4 additions & 6 deletions)

@@ -57,10 +57,8 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
-    print("before transpose, ", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
-    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 @implements(aten.addmm.default)
@@ -80,8 +78,7 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm input tensor shape:", input_tensor.shape)
-    print("mm weight tensor shape:", weight_tensor.shape)
+    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)
@@ -172,6 +169,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 
 # Shard the models
 d_up = colwise_shard(q_up, mesh)
+print("d_up weight shape:", d_up.linear.weight.shape)
 d_dn = rowwise_shard(q_dn, mesh)
 
 # We need to turn inputs into DTensor form as well -- just a format change
@@ -188,10 +186,10 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-c_up = torch.compile(d_up, backend="eager")
+c_up = torch.compile(d_up)
 y_up = c_up(input_dtensor)
 print("y_up:", y_up.shape)
-c_dn = torch.compile(d_dn, backend="eager")
+c_dn = torch.compile(d_dn)
 y_dn = c_dn(y_up)
 print("y_dn:", y_dn.shape)
 print("compiled result:", y_dn)
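
In tensor_parallel.py the per-call shape prints are replaced by a single check of the weight's transposed flag plus a one-time print of the sharded weight shape, and torch.compile now runs with its default backend instead of backend="eager". For orientation, a rough sketch of the path the aten.mm override sees, hand-unrolling what F.linear does for the quantized weight; q_up and its .linear.weight come from the tutorial, everything else here is illustrative and not part of this commit:

    import torch

    w = q_up.linear.weight                       # quantized MyDTypeTensorTP weight, before sharding
    x = torch.randn(128, 1024, device=w.device)  # activation with the tutorial's shape
    w_t = w.t()                                  # aten.t.default: a new wrapper whose layout_tensor has transposed=True
    print("mm weight transposed:", w_t.layout_tensor.transposed)  # the check this commit adds
    y = torch.mm(x, w_t.dequantize())            # like the override: dequantize() re-applies .t() for transposed weights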
