You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index 91fbc396f8e..09a2bf8f183 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -544,12 +544,18 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
def test_dynamo_dtensor_from_local_redistribute(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ from torch.distributed._functional_collectives import AsyncCollectiveTensor
# pass in tensor as inputs/outputs, create DTensor and run redistribute
# (allgather collective) inside the fn
def fn(x):
dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
- return dt.redistribute(mesh, [Replicate()]).to_local() + 2
+ out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
+ return out
+ if isinstance(out, AsyncCollectiveTensor):
+ return out.wait()
+ else:
+ return out
x = torch.ones(1)
ref = fn(x)
# run with `python test/distributed/_tensor/test_dtensor_compile.py -k test_dynamo_dtensor_from_local_redistribute`
This fails with:
File "/home/hirsheybar/local/a/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 167, in run
self.env[node] = self.run_node(node)
File "/home/hirsheybar/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6670, in run_node
result = super().run_node(n)
File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 228, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/hirsheybar/local/a/pytorch/torch/fx/interpreter.py", line 332, in call_method
return getattr(self_obj, target)(*args_tail, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='<torch._dynamo.testing.CompileCounterWithBackend object at 0x7f54e584b1c0>' raised:
AttributeError: 'FunctionalTensor' object has no attribute 'wait'
While executing %wait : [num_users=1] = call_method[target=wait](args = (%out,), kwargs = {})
Original traceback:
File "/home/hirsheybar/local/a/pytorch/test/distributed/_tensor/test_dtensor_compile.py", line 586, in fn
return out.wait()
The text was updated successfully, but these errors were encountered:
See #2352
torch-only repro:
This fails with:
The text was updated successfully, but these errors were encountered: