How does fake tensor work with tensor subclass in torch.compile? #136287

Open
jerryzh168 opened this issue Sep 18, 2024 · 5 comments
Labels
module: fakeTensor module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op, oncall: pt2 tensor subclass Related to tensor subclasses triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


jerryzh168 commented Sep 18, 2024

🐛 Describe the bug

I'm working on an example for quantized tensor subclass + DTensor (tensor parallel) + compile: pytorch/ao#785

The test works in eager mode, but currently fails with a shape mismatch under compile.

input shape: (128, 1024), linear weight shape: (512, 1024) (out * in)
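For reference, F.linear transposes the weight internally, so those shapes are consistent in plain eager PyTorch. A minimal standalone sketch (not from the repro):

import torch
import torch.nn.functional as F

x = torch.randn(128, 1024)  # input: (batch, in_features)
w = torch.randn(512, 1024)  # weight: (out_features, in_features)
y = F.linear(x, w)          # computes x @ w.t() -> (128, 512)
print(y.shape)              # torch.Size([128, 512])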

It errors in the torch.mm op with fake tensor:

[rank2]:     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 4333, in matmul
[rank2]:     return torch.mm(tensor1, tensor2)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 215, in dispatch
[rank2]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torchao-0.6.0+gitbd264f91-py3.10-linux-x86_64.egg/torchao/utils.py", line 372, in _dispatch__torch_function__
[rank2]:     return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torchao-0.6.0+gitbd264f91-py3.10-linux-x86_64.egg/torchao/utils.py", line 355, in wrapper
[rank2]:     return func(f, types, args, kwargs)
[rank2]:   File "/data/users/jerryzh/ao/tutorials/developer_api_guide/tensor_parallel.py", line 86, in _
[rank2]:     return aten.mm(input_tensor, weight_tensor)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
[rank2]:     return self._op(*args, **(kwargs or {}))
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank2]:     return self.dispatch(func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank2]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
[rank2]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2039, in _dispatch_impl
[rank2]:     r = func(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
[rank2]:     result = fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2100, in meta_mm
[rank2]:     torch._check(
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 1565, in _check
[rank2]:     _check_with(RuntimeError, cond, message)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 1547, in _check_with
[rank2]:     raise error_type(message_evaluated)
[rank2]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(512, 1024)), shape=torch.Size([512, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
[rank2]: a and b must have same reduction dim, but got [128, 1024] X [512, 1024].

The transpose implementation looks like the following:

# `implements` and `aten` (torch.ops.aten) come from the tutorial file;
# `return_and_correct_aliasing` is from torch.utils._python_dispatch.
@implements(aten.t.default)
def _(func, types, args, kwargs):
    tensor = args[0]
    print("before transpose, ", tensor.shape)
    # flip the (out, in) shape and rewrap the transposed inner tensor
    shape = tensor.shape[::-1]
    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
    print("after transpose:", new.shape)
    return return_and_correct_aliasing(func, args, kwargs, new)

It seems that the fake tensor did not pick up the changes to the shape in this case.
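For background on the title question: torch.compile fake-ifies a wrapper subclass through the traceable-subclass protocol, flattening it into its inner tensors plus metadata and rebuilding it around fake inner tensors. A rough sketch (the constructor and the layout_tensor attribute are assumptions mirroring the tutorial class; __tensor_flatten__/__tensor_unflatten__ are the real protocol hooks):

import torch

class MySubclass(torch.Tensor):
    @staticmethod
    def __new__(cls, layout_tensor, shape, dtype):
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=dtype)

    def __init__(self, layout_tensor, shape, dtype):
        self.layout_tensor = layout_tensor

    def __tensor_flatten__(self):
        # names of inner tensor attributes + non-tensor metadata
        return ["layout_tensor"], (self.shape, self.dtype)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # compile rebuilds the subclass around fake inner tensors here;
        # the outer shape is whatever the metadata above reported
        shape, dtype = meta
        return MySubclass(inner_tensors["layout_tensor"], shape, dtype)

So the shape FakeTensor sees for the wrapper is the one the subclass's own metadata reports at fake-ification time.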

Repro:

Versions

main

cc @ezyang @albanD @chauhang @penguinwu @eellison @zou3519 @bdhirsh


bdhirsh commented Sep 18, 2024

I tried running the repro with compile turned off, and I get a similar shape mismatch error from DTensor:

input tensor shape: torch.Size([128, 1024])
weight tensor shape: torch.Size([1024, 512])
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/hirsheybar/local/b/pytorch/ao/tutorials/developer_api_guide/tensor_parallel.py", line 192, in <module>
[rank2]:     y_up = c_up(input_dtensor)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/ao/tutorials/developer_api_guide/tensor_parallel.py", line 95, in forward
[rank2]:     return self.linear(x)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/nn/modules/linear.py", line 125, in forward
[rank2]:     return F.linear(input, self.weight, self.bias)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/distributed/tensor/_dispatch.py", line 215, in dispatch
[rank2]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/_ops.py", line 716, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/ao/torchao/utils.py", line 372, in _dispatch__torch_function__
[rank2]:     return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/ao/torchao/utils.py", line 355, in wrapper
[rank2]:     return func(f, types, args, kwargs)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/ao/tutorials/developer_api_guide/tensor_parallel.py", line 86, in _
[rank2]:     return aten.mm(input_tensor, weight_tensor)
[rank2]:   File "/home/hirsheybar/local/b/pytorch/torch/_ops.py", line 1116, in __call__
[rank2]:     return self._op(*args, **(kwargs or {}))
[rank2]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x1024 and 512x1024)

So the tutorial seems to be wrong in eager mode? If you compile a model that is supposed to error on incorrect shapes, then the fake tensor error you are seeing is expected (FakeTensor will raise a similar error when it does fake tensor propagation).
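A quick standalone way to see that (nothing from the repro): the same mismatched mm fails identically under FakeTensorMode, since fake tensors run the same meta shape checks as eager:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# tensors created inside the mode are fake (shape/dtype only, no data)
with FakeTensorMode():
    a = torch.empty(128, 1024)
    b = torch.empty(512, 1024)
    torch.mm(a, b)  # RuntimeError: a and b must have same reduction dim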

jerryzh168 commented

I can repro this with c_up = torch.compile(d_up, backend="eager") as well, so does that mean it's a problem in Dynamo?
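(For anyone following along, a sketch of the narrowing-down logic here, with d_up and x standing in for the tutorial's module and input:)

import torch

y = d_up(x)                                  # plain eager
y = torch.compile(d_up, backend="eager")(x)  # Dynamo + FakeTensor tracing, eager kernels
y = torch.compile(d_up)(x)                   # full stack, including Inductor codegen
# If the middle line fails while the first succeeds, the failure is in the
# tracing front end (Dynamo / fake tensor propagation), not the backend compiler.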

@jerryzh168 jerryzh168 added tensor subclass Related to tensor subclasses module: fakeTensor labels Sep 18, 2024
@pytorch-bot pytorch-bot bot added module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op, oncall: pt2 labels Sep 18, 2024

bdhirsh commented Sep 18, 2024

If it also repros without compile at all (e.g., removing compile and setting c_up = d_up), then that would imply it's not a problem with compile (and that the test code itself has a shape mismatch issue).

jerryzh168 commented

OK, it turns out this specific issue is because the transpose op is implemented as an in-place op right now, which is why it fails the second time we run it (see the sketch after the traceback below). I just updated the PR, and now there is a new error:

[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1595, in wrap_fake_exception
[rank0]:     return fn()
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2037, in <lambda>
[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2169, in run_node
[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2151, in run_node
[rank0]:     return node.target(*args, **kwargs)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 271, in _fn
[rank0]:     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 4333, in matmul
[rank0]:     return torch.mm(tensor1, tensor2)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in _torch_dispatch_
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 169, in dispatch
[rank0]:     self.sharding_propagator.propagate(op_info)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 198, in propagate
[rank0]:     output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 214, in propagate_op_sharding_non_cached
[rank0]:     out_tensor_meta = self._propagate_tensor_meta(op_schema)
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_op_schema.py", line 346, in _hash_
[rank0]:     return hash((self.op, args_to_hash))
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_dtensor_spec.py", line 68, in _hash_
[rank0]:     self._hash = self._hash_impl()
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_dtensor_spec.py", line 51, in _hash_impl
[rank0]:     return hash(
[rank0]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 531, in _hash_
[rank0]:     raise TypeError("unhashable type: non-nested SymInt")
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, s1)), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=1),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(1024, 512)), shape=torch.Size([1024, 512]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=1),)), None), **{}):
[rank0]: unhashable type: non-nested SymInt
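For context, the sketch promised above: a minimal standalone reproduction of the in-place pitfall, using plain tensors and made-up names rather than the torchao subclass:

import torch

state = {"weight": torch.randn(512, 1024)}  # (out_features, in_features)

def bad_t():
    # BUG: the transpose replaces the stored weight instead of
    # leaving it untouched and returning a fresh result
    state["weight"] = state["weight"].t()
    return state["weight"]

x = torch.randn(128, 1024)
y1 = x @ bad_t()  # 1st call: weight becomes (1024, 512), mm succeeds
y2 = x @ bad_t()  # 2nd call: weight flips back to (512, 1024) and mm fails:
                  # "mat1 and mat2 shapes cannot be multiplied (128x1024 and 512x1024)"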


bdhirsh commented Sep 19, 2024

Oh, that latest error should actually be fixed by #136266 (comment)

@masnesral masnesral added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 19, 2024