
Dynamo doesn't handle branching on AsyncCollectiveTensor well #2353

Closed
bdhirsh opened this issue Dec 4, 2024 · 1 comment
bdhirsh commented Dec 4, 2024

See #2352

torch-only repro:

diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index 91fbc396f8e..09a2bf8f183 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -544,12 +544,17 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):

     def test_dynamo_dtensor_from_local_redistribute(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        from torch.distributed._functional_collectives import AsyncCollectiveTensor

         # pass in tensor as inputs/outputs, create DTensor and run redistribute
         # (allgather collective) inside the fn
         def fn(x):
             dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
-            return dt.redistribute(mesh, [Replicate()]).to_local() + 2
+            out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
+            if isinstance(out, AsyncCollectiveTensor):
+                return out.wait()
+            else:
+                return out

         x = torch.ones(1)
         ref = fn(x)
# run with `python test/distributed/_tensor/test_dtensor_compile.py -k test_dynamo_dtensor_from_local_redistribute`

This fails with:

  File "/home/hirsheybar/local/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
  File "/home/hirsheybar/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6670, in run_node
    result = super().run_node(n)
  File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 228, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 332, in call_method
    return getattr(self_obj, target)(*args_tail, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='<torch._dynamo.testing.CompileCounterWithBackend object at 0x7f54e584b1c0>' raised:
AttributeError: 'FunctionalTensor' object has no attribute 'wait'

While executing %wait : [num_users=1] = call_method[target=wait](args = (%out,), kwargs = {})
Original traceback:
  File "/home/hirsheybar/local/a/pytorch/test/distributed/_tensor/test_dtensor_compile.py", line 586, in fn
    return out.wait()
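For context, the branching pattern the repro relies on can be sketched without any torch dependency. In the sketch below, `AsyncHandle` and the module-level `to_local` are illustrative stand-ins (not torch APIs) for `AsyncCollectiveTensor` and `DTensor.redistribute(...).to_local()`: with `async_op=True` the caller may get back either a plain value or an async wrapper that must be waited on, and it branches with `isinstance()` to handle both cases.

```python
class AsyncHandle:
    """Stand-in for AsyncCollectiveTensor: wraps a pending result."""

    def __init__(self, value):
        self._value = value

    def wait(self):
        # In the real API, wait() blocks until the collective finishes
        # and returns the materialized tensor.
        return self._value


def to_local(x, async_op=False):
    # Stand-in for DTensor.redistribute(...).to_local(): returns an
    # async wrapper when the collective is issued asynchronously.
    return AsyncHandle(x) if async_op else x


def fn(x, async_op):
    out = to_local(x, async_op=async_op)
    # This is the branch Dynamo must trace. The failure above arises
    # because under compilation `out` is a FunctionalTensor, which has
    # no `wait` method, so the traced `call_method[target=wait]` fails.
    if isinstance(out, AsyncHandle):
        return out.wait()
    return out
```

In eager mode both branches behave correctly; the issue is that the compiled graph records the `wait` call against a tensor type that does not implement it.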

bdhirsh commented Dec 4, 2024

whoops, I meant to file this in pytorch/pytorch - closing (pytorch/pytorch#142076)

bdhirsh closed this as completed Dec 4, 2024