diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 6817bce0235..778347b8ef3 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -54,11 +54,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me )._prepare_output_fn( output_layouts, use_local_output, mod, outputs, device_mesh ) - # wait for the output to be ready - if isinstance(outputs, AsyncCollectiveTensor): - return outputs.wait() - else: - return outputs + return torch.distributed._functional_collectives.wait_tensor(outputs) def tensor_parallel(