From 9e764e7b105a483ebc702cad33922ba8d8c210e1 Mon Sep 17 00:00:00 2001
From: cennn <61925104+cennn@users.noreply.github.com>
Date: Mon, 6 Jan 2025 09:05:48 +0800
Subject: [PATCH] [distributed] remove pynccl's redundant change_state (#11749)

---
 tests/distributed/test_pynccl.py              | 64 ++++++++-----------
 .../device_communicators/pynccl.py            | 17 -----
 vllm/distributed/parallel_state.py            |  9 +--
 3 files changed, 28 insertions(+), 62 deletions(-)

diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a77b48d5e49f3..a8571a1157892 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -59,8 +59,7 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
@@ -81,17 +80,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(graph):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
@@ -167,8 +164,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -205,8 +201,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -233,15 +228,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
 
@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
                              1024,
                              dtype=torch.float32,
                              device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 93d96fd8f5686..fda4d007ceb5b 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -213,19 +212,3 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
-
-    @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
-        """
-        A context manager to change the state of the communicator.
-        """
-        if enable is None:
-            # guess a default value when not specified
-            enable = self.available
-
-        old_disable = self.disabled
-
-        self.disabled = not enable
-        yield
-
-        self.disabled = old_disable
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index dccd3addbcb35..a837c1dc5953b 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -305,14 +305,7 @@ def graph_capture(
             stream.wait_stream(curr_stream)
 
         with torch.cuda.stream(stream), maybe_ca_context:
-            pynccl_comm = self.pynccl_comm
-            maybe_pynccl_context: Any
-            if not pynccl_comm:
-                maybe_pynccl_context = nullcontext()
-            else:
-                maybe_pynccl_context = pynccl_comm.change_state()
-            with maybe_pynccl_context:
-                yield graph_capture_context
+            yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
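
Illustrative note (not part of the patch): after this change, PyNcclCommunicator collectives are called directly and graph_capture() no longer wraps them in change_state(); the communicator's disabled flag is simply whatever __init__ sets, rather than being toggled through a context manager. A minimal sketch of the resulting calling convention, assuming the distributed environment and world group are already initialized (as the tests' worker wrapper does) and each process owns one GPU:

    import torch
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.parallel_state import get_world_group

    # Build the communicator once per process, as the tests in this patch do.
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    # Invoke the collective directly; no "with pynccl_comm.change_state(...):"
    # wrapper is needed any more.
    tensor = torch.ones(1024, dtype=torch.float32,
                        device=get_world_group().device)
    tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()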