Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Bugfix] Use correct device to initialize GPU data during CUDA-graph-capture #11233

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/distributed/test_custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):

for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture() as graph_capture_context:
with graph_capture(device=device) as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
Expand Down
2 changes: 1 addition & 1 deletion tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture():
with graph_capture(device=device):
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)
Expand Down
7 changes: 4 additions & 3 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:


@contextmanager
def graph_capture():
def graph_capture(device: torch.device):
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
Expand All @@ -934,8 +934,9 @@ def graph_capture():
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
with get_tp_group().graph_capture() as context, get_pp_group(
).graph_capture(context):
context = GraphCaptureContext(torch.cuda.Stream(device=device))
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
context):
yield context


Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def capture_model(self) -> None:
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture():
with graph_capture(device=self.device):
for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
Expand Down
25 changes: 16 additions & 9 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,10 +1425,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:

# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_tokens = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
input_positions = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions, (3, 1))
input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device)
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
previous_hidden_states = None
Expand All @@ -1447,8 +1452,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
dtype=self.model_config.dtype,
device=self.device)

with self.attn_state.graph_capture(
max_batch_size), graph_capture() as graph_capture_context:
with self.attn_state.graph_capture(max_batch_size), graph_capture(
self.device) as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for virtual_engine in range(
Expand Down Expand Up @@ -1548,10 +1553,12 @@ def _update_inputs_to_capture_for_enc_dec_model(self,
"""
# During the decode phase encoder_input_ids and encoder_positions are
# unset. Do the same thing for graph capture.
capture_inputs["encoder_input_ids"] = torch.tensor(
[], dtype=torch.long).cuda()
capture_inputs["encoder_positions"] = torch.tensor(
[], dtype=torch.long).cuda()
capture_inputs["encoder_input_ids"] = torch.tensor([],
dtype=torch.long,
device=self.device)
capture_inputs["encoder_positions"] = torch.tensor([],
dtype=torch.long,
device=self.device)

@property
def vocab_size(self) -> int:
Expand Down
Loading