[perf] fix current stream #11870

Merged · 2 commits · Jan 9, 2025
15 changes: 8 additions & 7 deletions vllm/distributed/device_communicators/pynccl.py
@@ -10,6 +10,7 @@
ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
+ from vllm.utils import current_stream

logger = init_logger(__name__)

@@ -96,7 +97,7 @@ def __init__(
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)

- stream = torch.cuda.current_stream()
+ stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
@@ -119,7 +120,7 @@ def all_reduce(self,
out_tensor = torch.empty_like(in_tensor)

if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
@@ -141,7 +142,7 @@ def all_gather(self,
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +163,7 @@ def reduce_scatter(self,
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
- stream = torch.cuda.current_stream()
+ stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
5 changes: 1 addition & 4 deletions vllm/distributed/parallel_state.py
@@ -357,10 +357,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
- # TODO: pynccl should not use `stream=`
- # it can just always use the current stream.
- out = pynccl_comm.all_reduce(input_,
-     stream=torch.cuda.current_stream())
+ out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
33 changes: 33 additions & 0 deletions vllm/utils.py
@@ -956,6 +956,39 @@ def find_nccl_library() -> str:
return so_file


prev_set_stream = torch.cuda.set_stream

_current_stream = None


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
global _current_stream
_current_stream = stream
prev_set_stream(stream)


torch.cuda.set_stream = _patched_set_stream
Comment on lines +959 to +970

Collaborator: It looks like we're not using set_stream anywhere in the vllm codebase. Could you add a unit test for this to make sure it's exercised?

Collaborator (quoting the docstring, "here we patch torch.cuda.set_stream to keep track of the current stream directly, so that we can avoid calling torch.cuda.current_stream()"): I might be confused about how utils.current_stream() works, though.

Member (author): torch.cuda.graph calls it internally to switch streams, so any test case that combines CUDA graphs with NCCL exercises this PR's code.
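
A minimal sketch of the kind of unit test the first comment asks for, assuming a CUDA-capable machine; the test name and the explicit switch back to the default stream are illustrative and not part of this PR:

```python
import torch

from vllm.utils import current_stream


def test_current_stream_tracks_set_stream():
    assert torch.cuda.is_available()

    side_stream = torch.cuda.Stream()
    # torch.cuda.set_stream is monkey-patched in vllm.utils, so switching
    # streams here should be reflected by current_stream() without querying
    # the CUDA runtime again.
    torch.cuda.set_stream(side_stream)
    assert current_stream() == side_stream

    # Switch back so later tests run on the default stream again.
    torch.cuda.set_stream(torch.cuda.default_stream())
    assert current_stream() == torch.cuda.default_stream()
```

As the author notes, torch.cuda.graph switches streams through the same code path, so existing CUDA-graph + NCCL tests also exercise the patched function.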



def current_stream() -> torch.cuda.Stream:
"""
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
it turns out that `torch.cuda.current_stream()` is quite expensive,
as it will construct a new stream object at each call.
here we patch `torch.cuda.set_stream` to keep track of the current stream
directly, so that we can avoid calling `torch.cuda.current_stream()`.

the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
from C/C++ code.
"""
global _current_stream
if _current_stream is None:
# when this function is called before any stream is set,
# we return the default stream.
_current_stream = torch.cuda.current_stream()
return _current_stream


def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
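To make the cost claim in the new docstring concrete (torch.cuda.current_stream() wraps the active stream in a fresh Stream object on every call, while the patched helper returns a cached one), here is a rough micro-benchmark sketch; the iteration count and the timing harness are arbitrary and not part of the PR:

```python
import time

import torch

from vllm.utils import current_stream

assert torch.cuda.is_available()
torch.cuda.init()

N = 100_000

start = time.perf_counter()
for _ in range(N):
    torch.cuda.current_stream()   # builds a new Stream wrapper each call
torch_time = time.perf_counter() - start

start = time.perf_counter()
for _ in range(N):
    current_stream()              # returns the cached Stream object
vllm_time = time.perf_counter() - start

print(f"torch.cuda.current_stream(): {torch_time:.3f}s for {N:,} calls")
print(f"vllm.utils.current_stream(): {vllm_time:.3f}s for {N:,} calls")
```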
8 changes: 4 additions & 4 deletions vllm/worker/multi_step_model_runner.py
@@ -14,7 +14,7 @@
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
- from vllm.utils import PyObjectCache, async_tensor_h2d
+ from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
Expand Down Expand Up @@ -498,7 +498,7 @@ def execute_model(
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of GPU

- current_stream = torch.cuda.current_stream()
+ stream = current_stream()
if not model_input.is_first_multi_step:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any GPU tensors still in use.
@@ -541,7 +541,7 @@
num_steps=1)

# record the event for the current step so that the next step can sync
- model_input.record_step_event(current_stream)
+ model_input.record_step_event(stream)

if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
Expand All @@ -552,7 +552,7 @@ def execute_model(
# event for the pythonization so that we only pythonize if the
# tensors are ready. May be able to be combined with the step event
output_ready_event = torch.cuda.Event()
- output_ready_event.record(current_stream)
+ output_ready_event.record(stream)
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()