From 243ebc581f0c0908bbe675af6ec4a6955f712233 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 18 Sep 2024 17:37:40 -0700
Subject: [PATCH] change transpose to non-inplace op

---
 .../developer_api_guide/my_dtype_tensor_subclass.py |  5 ++---
 tutorials/developer_api_guide/tensor_parallel.py    | 10 ++++------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
index af027275b6..fc967441ed 100644
--- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
+++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -194,7 +194,7 @@ def dequantize(self, output_dtype=None):
         block_size = (1, int_data.shape[-1])
         if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
             transposed = True
-        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype)
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
         if transposed:
             res = res.t()
         return res
@@ -331,8 +331,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             elif dim == 1:
                 return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
         elif func is aten.t.default:
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
 
         raise NotImplementedError(
             f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py
index 233d75ff10..577b7e5b6c 100644
--- a/tutorials/developer_api_guide/tensor_parallel.py
+++ b/tutorials/developer_api_guide/tensor_parallel.py
@@ -57,10 +57,8 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
-    print("before transpose, ", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
-    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 @implements(aten.addmm.default)
@@ -80,8 +78,7 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm input tensor shape:", input_tensor.shape)
-    print("mm weight tensor shape:", weight_tensor.shape)
+    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)
 
@@ -172,6 +169,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 
 # Shard the models
 d_up = colwise_shard(q_up, mesh)
+print("d_up weight shape:", d_up.linear.weight.shape)
 d_dn = rowwise_shard(q_dn, mesh)
 
 # We need to turn inputs into DTensor form as well -- just a format change
@@ -188,10 +186,10 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-c_up = torch.compile(d_up, backend="eager")
+c_up = torch.compile(d_up)
 y_up = c_up(input_dtensor)
 print("y_up:", y_up.shape)
-c_dn = torch.compile(d_dn, backend="eager")
+c_dn = torch.compile(d_dn)
 y_dn = c_dn(y_up)
 print("y_dn:", y_dn.shape)
 print("compiled result:", y_dn)
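
The aten.t.default change above illustrates a general pattern for handling view-like ops functionally in __torch_dispatch__: build a new wrapper with the flipped transposed flag and return it through return_and_correct_aliasing, rather than mutating args[0] in place. Below is a minimal, self-contained sketch of that pattern; PlainLayoutSketch and its int_data/scale/transposed fields are illustrative stand-ins chosen for this example, not code taken from the patch.

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten

class PlainLayoutSketch(torch.Tensor):
    """Illustrative stand-in for a plain layout tensor subclass (hypothetical)."""

    @staticmethod
    def __new__(cls, int_data, scale, transposed):
        # Logical shape flips when the transposed flag is set.
        shape = int_data.shape[::-1] if transposed else int_data.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=scale.dtype)

    def __init__(self, int_data, scale, transposed):
        self.int_data = int_data
        self.scale = scale
        self.transposed = transposed

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is aten.t.default:
            self = args[0]
            # Non-in-place: return a fresh instance with the flag flipped,
            # instead of assigning to self.transposed.
            new = PlainLayoutSketch(self.int_data, self.scale, not self.transposed)
            return return_and_correct_aliasing(func, args, kwargs, new)
        raise NotImplementedError(f"{func} is not supported in this sketch")

# Usage sketch:
# x = PlainLayoutSketch(torch.randint(-8, 8, (4, 2), dtype=torch.int8), torch.ones(4), False)
# y = x.t()                      # dispatches to aten.t.default above
# print(y.transposed, y.shape)   # True, torch.Size([2, 4]); x is left unchanged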