Make torch TP composable with torch.compile #2352

Merged: 1 commit merged into sgl-project:main on Dec 5, 2024

Conversation

@kwen2501 (Contributor) commented on Dec 4, 2024

Motivation

Previously, we had this block of code in RowwiseParallel:

# wait for the output to be ready
if isinstance(outputs, AsyncCollectiveTensor):
  return outputs.wait()
else:
  return outputs

When Dynamo traces this code, outputs is an AsyncCollectiveTensor, so the outputs.wait() call gets burned into the compiled graph. At run time, however, outputs is somehow a normal tensor, so we hit the following error:

AttributeError: 'FunctionalTensor' object has no attribute 'wait'

This is a bug in Dynamo.
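
For intuition, here is a minimal, non-distributed sketch of the failure mode (illustrative only; traced_epilogue is a stand-in, not the actual Dynamo-generated code):

import torch

def traced_epilogue(outputs):
    # Illustrative stand-in for what the compiled graph effectively does
    # after tracing with an AsyncCollectiveTensor: the isinstance check is
    # gone and only the "async" branch survives.
    return outputs.wait()

plain = torch.ones(4)
try:
    traced_epilogue(plain)
except AttributeError as e:
    # Mirrors the error above: a plain tensor has no .wait() method.
    print(e)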

Modifications

To work around the Dynamo bug above, we replace the if-else block with a single call that has deterministic behavior:

torch.distributed._functional_collectives.wait_tensor(outputs)

wait_tensor accepts both a regular tensor and an AsyncCollectiveTensor; how it handles each is an internal implementation detail. This works in both eager and compiled modes.
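
For concreteness, a minimal sketch of the reworked output handling (the helper name _prepare_output is hypothetical; only the unconditional wait_tensor call reflects this PR's change):

import torch
from torch.distributed._functional_collectives import wait_tensor

def _prepare_output(outputs: torch.Tensor) -> torch.Tensor:
    # No type check needed: wait_tensor is a no-op for regular tensors and
    # waits on AsyncCollectiveTensor results, so the same call is valid in
    # both eager and compiled modes.
    return wait_tensor(outputs)

Because the same call is traced and executed regardless of the output's runtime type, the compiled graph no longer depends on which tensor subclass Dynamo happened to see at trace time.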

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

cc @jerryzh168 @bdhirsh

@bdhirsh commented on Dec 4, 2024

Here's a minimal repro:

diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py
index 91fbc396f8e..09a2bf8f183 100644
--- a/test/distributed/_tensor/test_dtensor_compile.py
+++ b/test/distributed/_tensor/test_dtensor_compile.py
@@ -544,12 +544,18 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):

     def test_dynamo_dtensor_from_local_redistribute(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        from torch.distributed._functional_collectives import AsyncCollectiveTensor

         # pass in tensor as inputs/outputs, create DTensor and run redistribute
         # (allgather collective) inside the fn
         def fn(x):
             dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
-            return dt.redistribute(mesh, [Replicate()]).to_local() + 2
+            out = dt.redistribute(mesh, [Replicate()], async_op=True).to_local()
+            return out
+            if isinstance(out, AsyncCollectiveTensor):
+                return out.wait()
+            else:
+                return out

         x = torch.ones(1)
         ref = fn(x)

@bdhirsh commented on Dec 4, 2024

Fixed in core with pytorch/pytorch#142075

@merrymercy merged commit d693ec0 into sgl-project:main on Dec 5, 2024
15 checks passed