Skip to content

Commit

Permalink
[TPU] Async output processing for TPU (#8011)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Aug 30, 2024
1 parent 428dd14 commit 80c7b08
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 3 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
self.use_async_output_proc = False
return

if device_config.device_type != "cuda":
if device_config.device_type not in ("cuda", "tpu"):
logger.warning(
"Async output processing is only supported for CUDA."
" Disabling it for other platforms.")
"Async output processing is only supported for CUDA or TPU. "
"Disabling it for other platforms.")
self.use_async_output_proc = False
return

Expand Down
8 changes: 7 additions & 1 deletion vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
best_of: List[int]
seq_groups: List[List[int]]
virtual_engine: int = 0
async_callback: Optional[Callable] = None

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
Expand Down Expand Up @@ -562,6 +564,8 @@ def _execute_model(*args):
model_input.attn_metadata, model_input.input_lens[i:i + 1],
model_input.t[i:i + 1], model_input.p[i:i + 1],
model_input.num_samples, kv_caches)
if i == 0 and model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx
Expand All @@ -572,6 +576,8 @@ def _execute_model(*args):
model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples,
kv_caches)
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist()

Expand Down

0 comments on commit 80c7b08

Please sign in to comment.