[Core][Bugfix] Use correct device to initialize GPU data during CUDA-graph-capture

Signed-off-by: Yan Burman <[email protected]>
Signed-off-by: Ido Asraff <[email protected]>
yanburman committed Jan 1, 2025
1 parent 6d70198 commit c62131a
Showing 5 changed files with 23 additions and 15 deletions.
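For context on the fix: a parameterless .cuda() call allocates on whatever CUDA device is current, which in a multi-GPU worker is not necessarily the device that worker was assigned. A minimal sketch of the failure mode and the corrected pattern (the device choice is illustrative and assumes a machine with at least two GPUs):

import torch

# Illustrative setup: this worker owns cuda:1, but nothing has called
# torch.cuda.set_device(1), so the current device is still cuda:0.
device = torch.device("cuda:1")

# Buggy pattern: .cuda() with no argument lands on the *current* device.
t_wrong = torch.zeros(8, dtype=torch.long).cuda()  # allocated on cuda:0

# Fixed pattern: name the target device explicitly.
t_right = torch.zeros(8, dtype=torch.long, device=device)  # on cuda:1

assert t_right.device == device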
tests/distributed/test_custom_all_reduce.py (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
 
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with graph_capture() as graph_capture_context:
+            with graph_capture(device=device) as graph_capture_context:
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
tests/distributed/test_pynccl.py (2 changes: 1 addition & 1 deletion)
@@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_capture():
+    with graph_capture(device=device):
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
vllm/distributed/parallel_state.py (7 changes: 4 additions & 3 deletions)
@@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
 
 
 @contextmanager
-def graph_capture():
+def graph_capture(device: torch.device):
     """
     `graph_capture` is a context manager which should surround the code that
     is capturing the CUDA graph. Its main purpose is to ensure that the
@@ -934,8 +934,9 @@ def graph_capture():
     in order to explicitly distinguish the kernels to capture
     from other kernels possibly launched on background in the default stream.
     """
-    with get_tp_group().graph_capture() as context, get_pp_group(
-    ).graph_capture(context):
+    context = GraphCaptureContext(torch.cuda.Stream(device=device))
+    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
+            context):
         yield context


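With this change the capture stream is created explicitly on the caller's device rather than on the current one (torch.cuda.Stream() with no argument binds to the current device). A minimal usage sketch of the updated context manager, assuming the distributed groups are already initialized and that GraphCaptureContext exposes the stream it is constructed with, as the hunk above suggests:

import torch
from vllm.distributed.parallel_state import graph_capture

device = torch.device(f"cuda:{torch.distributed.get_rank()}")

# The side stream inside the context now lives on `device`, so kernels
# recorded into the CUDA graph target this worker's own GPU.
with graph_capture(device=device) as graph_capture_context:
    capture_stream = graph_capture_context.stream
    ...  # launch the kernels to be captured here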
vllm/v1/worker/gpu_model_runner.py (2 changes: 1 addition & 1 deletion)
@@ -741,7 +741,7 @@ def capture_model(self) -> None:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture():
+        with graph_capture(device=self.device):
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
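The "capture the large shapes first" comment refers to CUDA graph memory pooling: graphs captured into one shared pool can reuse each other's blocks, so recording the largest batch first lets the smaller batches fit into memory that is already reserved. A standalone sketch of that pattern in plain PyTorch; run_batch is an illustrative stand-in for the model forward, not vLLM code:

import torch

def run_batch(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1  # stand-in for a model forward pass

# Warm up once on a side stream, as recommended before graph capture.
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    run_batch(torch.zeros(8, 16, device="cuda"))
torch.cuda.current_stream().wait_stream(s)

pool = torch.cuda.graph_pool_handle()  # one memory pool shared by all graphs
graphs, inputs, outputs = {}, {}, {}

# Capture the largest batch first so smaller captures reuse its blocks.
for bs in sorted([1, 2, 4, 8], reverse=True):
    inputs[bs] = torch.zeros(bs, 16, device="cuda")
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        outputs[bs] = run_batch(inputs[bs])
    graphs[bs] = g

# Replay: refill the static input buffer in place, then launch the graph.
inputs[4].copy_(torch.randn(4, 16, device="cuda"))
graphs[4].replay()
result = outputs[4]  # holds the new output after replay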
vllm/worker/model_runner.py (25 changes: 16 additions & 9 deletions)
@@ -1425,10 +1425,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = self.max_batchsize_to_capture
-        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(max_batch_size,
+                                   dtype=torch.long,
+                                   device=self.device)
+        input_positions = torch.zeros(max_batch_size,
+                                      dtype=torch.long,
+                                      device=self.device)
         if self.model_config.uses_mrope:
-            input_positions = torch.tile(input_positions, (3, 1))
+            input_positions = torch.tile(input_positions,
+                                         (3, 1)).cuda(device=self.device)
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
         previous_hidden_states = None
@@ -1447,8 +1452,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 dtype=self.model_config.dtype,
                 device=self.device)
 
-        with self.attn_state.graph_capture(
-                max_batch_size), graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(max_batch_size), graph_capture(
+                self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
@@ -1548,10 +1553,12 @@ def _update_inputs_to_capture_for_enc_dec_model(self,
"""
# During the decode phase encoder_input_ids and encoder_positions are
# unset. Do the same thing for graph capture.
capture_inputs["encoder_input_ids"] = torch.tensor(
[], dtype=torch.long).cuda()
capture_inputs["encoder_positions"] = torch.tensor(
[], dtype=torch.long).cuda()
capture_inputs["encoder_input_ids"] = torch.tensor([],
dtype=torch.long,
device=self.device)
capture_inputs["encoder_positions"] = torch.tensor([],
dtype=torch.long,
device=self.device)

     @property
     def vocab_size(self) -> int:
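The dummy-input changes above follow the same principle as the rest of the commit: CUDA graphs replay against recorded pointers, so the capture buffers are static and must be allocated on the worker's own device up front. A sketch of the reuse pattern behind the "These will be reused for all batch sizes" comment, with illustrative names and sizes:

import torch

device = torch.device("cuda:0")  # illustrative: the worker's assigned device
max_batch_size = 8

# Allocate once, directly on the right device; no .cuda() fixups later.
input_tokens = torch.zeros(max_batch_size, dtype=torch.long, device=device)
input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=device)

for bs in sorted([1, 2, 4, 8], reverse=True):
    tokens = input_tokens[:bs]        # view into the static buffer
    positions = input_positions[:bs]  # no new allocation per batch size
    # ... capture the forward pass for this batch size ...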
