Skip to content

Commit

Permalink
[V1] [5/N] API Server: unify Detokenizer and EngineCore input (vl…
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat authored and xcnick committed Dec 31, 2024
1 parent 10c99f7 commit 6e19644
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 77 deletions.
57 changes: 35 additions & 22 deletions tests/v1/engine/test_detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pytest
from transformers import AutoTokenizer

from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.engine.detokenizer import Detokenizer

TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
Expand Down Expand Up @@ -71,16 +71,22 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):

# Make N requests.
requests = [
DetokenizerRequest(
request_id=f"request-{idx}",
prompt=prompt,
prompt_token_ids=prompt_tokens,
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
) for idx, (
EngineCoreRequest(request_id=f"request-{idx}",
prompt=prompt,
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None,
lora_request=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False))
for idx, (
prompt,
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
]
Expand Down Expand Up @@ -133,18 +139,25 @@ def test_stop_string(include_stop_str_in_output: bool):

# Make N requests.
requests = [
DetokenizerRequest(
EngineCoreRequest(
request_id=f"request-{idx}",
prompt=prompt,
prompt_token_ids=prompt_tokens,
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=STOP_STRINGS,
include_stop_str_in_output=include_stop_str_in_output,
) for idx, (
prompt,
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
arrival_time=0,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None,
lora_request=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=RequestOutputKind.DELTA,
stop=STOP_STRINGS,
include_stop_str_in_output=include_stop_str_in_output,
)) for idx, (
prompt,
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
]

# Add requests to the detokenizer.
Expand Down
16 changes: 1 addition & 15 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,7 @@

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind, SamplingParams


@dataclass
class DetokenizerRequest:

request_id: str
prompt: Optional[str]
prompt_token_ids: List[int]
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_kind: RequestOutputKind

stop: List[str]
include_stop_str_in_output: bool
from vllm.sampling_params import SamplingParams


@dataclass
Expand Down
14 changes: 8 additions & 6 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,18 @@ async def add_request(
raise ValueError(f"Request id {request_id} already running.")
self.rid_to_queue[request_id] = asyncio.Queue()

# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
trace_headers, prompt_adapter_request, priority)
# 2) Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)

# 3) Add the request to Detokenizer (this process).
self.detokenizer.add_request(detokenizer_req)
self.detokenizer.add_request(request)

# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req)
await self.engine_core.add_request_async(request)

if self.log_requests:
logger.info("Added request %s.", request_id)
Expand Down
21 changes: 11 additions & 10 deletions vllm/v1/engine/detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest

logger = init_logger(__name__)

Expand Down Expand Up @@ -55,19 +55,19 @@ def output_token_ids(self) -> List[int]:
def from_new_request(
cls,
tokenizer: AnyTokenizer,
request: DetokenizerRequest,
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":

tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.skip_special_tokens,
skip_special_tokens=request.sampling_params.skip_special_tokens,
)

stops = request.stop
stops = request.sampling_params.stop
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if stops and not request.include_stop_str_in_output:
if stops and not request.sampling_params.include_stop_str_in_output:
stop_buffer_length = max(len(s) for s in stops) - 1
else:
stop_buffer_length = 0
Expand All @@ -79,13 +79,14 @@ def from_new_request(
# NOTE(Nick): could we take ownership of it though?
token_ids=request.prompt_token_ids.copy(),
stop=stops,
include_stop_str_in_output=request.include_stop_str_in_output,
include_stop_str_in_output=request.sampling_params.
include_stop_str_in_output,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=request.
skip_special_tokens=request.sampling_params.skip_special_tokens,
spaces_between_special_tokens=request.sampling_params.
spaces_between_special_tokens,
output_kind=request.output_kind,
output_kind=request.sampling_params.output_kind,
request_id=request.request_id,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
Expand Down Expand Up @@ -227,7 +228,7 @@ def abort_requests(

def add_request(
self,
request: DetokenizerRequest,
request: EngineCoreRequest,
):
"""Add new request to the Detokenizer."""

Expand Down
12 changes: 7 additions & 5 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,17 @@ def add_request(
) -> None:

# 1) Process raw inputs into the request.
detokenizer_req, engine_core_req = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
trace_headers, prompt_adapter_request, priority)
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)

# 2) Add the request to Detokenizer.
self.detokenizer.add_request(detokenizer_req)
self.detokenizer.add_request(request)

# 3) Add the request to EngineCore.
self.engine_core.add_request(engine_core_req)
self.engine_core.add_request(request)

def step(self) -> List[RequestOutput]:

Expand Down
23 changes: 4 additions & 19 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Mapping, Optional, Tuple, Union
from typing import Mapping, Optional, Union

from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
Expand All @@ -13,7 +13,7 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient


Expand Down Expand Up @@ -62,7 +62,7 @@ def process_inputs(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
) -> EngineCoreRequest:

# TODO(woosuk): Support pooling models.
# TODO(woosuk): Check max_logprobs
Expand Down Expand Up @@ -123,20 +123,7 @@ def process_inputs(
decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)

# Make Request for Detokenizer.
detokenizer_request = DetokenizerRequest(
request_id,
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
sampling_params.skip_special_tokens,
sampling_params.spaces_between_special_tokens,
sampling_params.output_kind,
sampling_params.stop,
sampling_params.include_stop_str_in_output,
)

# Make Request for EngineCore.
engine_core_request = EngineCoreRequest(
return EngineCoreRequest(
request_id,
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
Expand All @@ -149,8 +136,6 @@ def process_inputs(
lora_request,
)

return detokenizer_request, engine_core_request

def _validate_model_inputs(self, inputs: ProcessorInputs):
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
Expand Down

0 comments on commit 6e19644

Please sign in to comment.