From 14753ba180c4dbfd1c5f78e520fdce39c672fd87 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 9 Jan 2025 15:18:21 +0800
Subject: [PATCH] [perf]fix current stream (#11870)

Signed-off-by: youkaichao
Signed-off-by: Fred Reiss
---
 .../device_communicators/pynccl.py     | 15 +++++----
 vllm/distributed/parallel_state.py     |  5 +--
 vllm/utils.py                          | 33 +++++++++++++++++++
 vllm/worker/multi_step_model_runner.py |  8 ++---
 4 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index fda4d007ceb5b..efc59987195f5 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -10,6 +10,7 @@
     ncclRedOpTypeEnum, ncclUniqueId)
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
+from vllm.utils import current_stream
 
 logger = init_logger(__name__)
 
@@ -96,7 +97,7 @@ def __init__(
         self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
             self.world_size, self.unique_id, self.rank)
 
-        stream = torch.cuda.current_stream()
+        stream = current_stream()
         # A small all_reduce for warmup.
         data = torch.zeros(1, device=device)
         self.all_reduce(data)
@@ -119,7 +120,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -141,7 +142,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +163,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index a837c1dc5953b..be7f16ef52a47 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -357,10 +357,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        # TODO: pynccl should not use `stream=`
-        # it can just always use the current stream.
-        out = pynccl_comm.all_reduce(input_,
-                                     stream=torch.cuda.current_stream())
+        out = pynccl_comm.all_reduce(input_)
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
diff --git a/vllm/utils.py b/vllm/utils.py
index a92b77efd9fd8..0b0905e675245 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -944,6 +944,39 @@ def find_nccl_library() -> str:
     return so_file
 
 
+prev_set_stream = torch.cuda.set_stream
+
+_current_stream = None
+
+
+def _patched_set_stream(stream: torch.cuda.Stream) -> None:
+    global _current_stream
+    _current_stream = stream
+    prev_set_stream(stream)
+
+
+torch.cuda.set_stream = _patched_set_stream
+
+
+def current_stream() -> torch.cuda.Stream:
+    """
+    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
+    it turns out that `torch.cuda.current_stream()` is quite expensive,
+    as it will construct a new stream object at each call.
+    here we patch `torch.cuda.set_stream` to keep track of the current stream
+    directly, so that we can avoid calling `torch.cuda.current_stream()`.
+
+    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
+    from C/C++ code.
+    """
+    global _current_stream
+    if _current_stream is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _current_stream = torch.cuda.current_stream()
+    return _current_stream
+
+
 def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index a2c2cebf8d1f6..acce923498d7e 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -14,7 +14,7 @@
                                                  get_pythonized_sample_results)
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
-from vllm.utils import PyObjectCache, async_tensor_h2d
+from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -498,7 +498,7 @@ def execute_model(
         #     appended sampler output from last iteration
         #   - also maybe pythonize if CPU is ahead of GPU
 
-        current_stream = torch.cuda.current_stream()
+        stream = current_stream()
         if not model_input.is_first_multi_step:
             # Explicitly block on the previous step's forward to make sure we
             # don't clobber any GPU tensors still in use.
@@ -541,7 +541,7 @@ def execute_model(
                                                        num_steps=1)
 
         # record the event for the current step so that the next step can sync
-        model_input.record_step_event(current_stream)
+        model_input.record_step_event(stream)
 
         if get_pp_group().is_last_rank and self.is_driver_worker:
             assert isinstance(output, list)
@@ -552,7 +552,7 @@ def execute_model(
             # event for the pythonization so that we only pythonize if the
             # tensors are ready. May be able to be combined with the step event
             output_ready_event = torch.cuda.Event()
-            output_ready_event.record(current_stream)
+            output_ready_event.record(stream)
             if self.parallel_config.pipeline_parallel_size > 1:
                 output[0].sampled_token_ids_cpu = output[
                     0].sampled_token_ids.cpu()
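-- 
As a quick sanity check of the new helper, the sketch below compares
`vllm.utils.current_stream()` with `torch.cuda.current_stream()` and verifies
that the patched `torch.cuda.set_stream` keeps the cached stream in sync. It
assumes a CUDA-capable machine running a vLLM build that includes this change;
the script name, assertions, timing loop, and iteration count are illustrative
and are not part of the commit.

    # check_current_stream.py -- illustrative sketch, not part of this commit
    import time

    import torch

    # importing vllm.utils installs the torch.cuda.set_stream patch
    from vllm.utils import current_stream

    assert torch.cuda.is_available(), "this sketch assumes a CUDA device"

    # before any explicit stream switch, both helpers report the default stream
    assert current_stream().cuda_stream == torch.cuda.current_stream().cuda_stream

    # after torch.cuda.set_stream, the patched tracking follows the new stream
    s = torch.cuda.Stream()
    torch.cuda.set_stream(s)
    assert current_stream().cuda_stream == s.cuda_stream


    def bench(fn, iters=100_000):
        # wall-clock time for `iters` back-to-back calls; no GPU work is launched
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        return time.perf_counter() - start


    # rough per-call overhead comparison; absolute numbers vary by machine
    print("torch.cuda.current_stream():", bench(torch.cuda.current_stream))
    print("vllm.utils.current_stream():", bench(current_stream))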