[V1][3/N] API Server: Reduce Task Switching + Handle Abort Properly #11534
Changes from all commits: 509f488, 1789162, 9239814, 83acac6, aefeb84, 8fb01d1, aeefcf2, 85f8ac7, 1749da1, e0ddb05, 60ef7aa, e02717e
@@ -9,14 +9,13 @@
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.v1.engine.async_stream import AsyncStream
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
@@ -54,10 +53,8 @@ def __init__(
             lora_config=vllm_config.lora_config)
         self.tokenizer.ping()

-        # Request streams (map of request_id -> AsyncStream).
-        self.request_streams: Dict[str, AsyncStream] = {}
-        # List of cancelled request ids to be aborted.
-        self.client_aborted_requests: List[str] = []
+        # Request streams (map of request_id -> queue).
+        self.rid_to_queue: Dict[str, asyncio.Queue] = {}

         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
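The rid_to_queue map replaces the old AsyncStream plumbing: the background output handler fans outputs out to one plain asyncio.Queue per request. Below is a minimal, self-contained sketch of that registry pattern; the class and names are generic illustrations, not vLLM's actual code.

import asyncio
from typing import Dict


class QueueFanout:
    """Toy illustration of a per-request queue registry (not vLLM code)."""

    def __init__(self) -> None:
        # request_id -> queue of outputs for that request.
        self.rid_to_queue: Dict[str, asyncio.Queue] = {}

    def add_request(self, request_id: str) -> asyncio.Queue:
        if request_id in self.rid_to_queue:
            raise ValueError(f"Request id {request_id} already running.")
        self.rid_to_queue[request_id] = asyncio.Queue()
        return self.rid_to_queue[request_id]

    def publish(self, request_id: str, output: object) -> None:
        # Skip ids that were already aborted and removed from the map.
        if request_id in self.rid_to_queue:
            self.rid_to_queue[request_id].put_nowait(output)


async def main() -> None:
    fanout = QueueFanout()
    q = fanout.add_request("req-1")
    fanout.publish("req-1", "hello")
    print(await q.get())  # -> "hello"


asyncio.run(main())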
@@ -153,14 +150,13 @@ async def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
+    ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""

-        if self.detokenizer.is_request_active(request_id):
-            raise ValueError(f"Request {request_id} already exists.")
-
-        # 1) Create a new AsyncStream for the request.
-        stream = self._add_request_to_streams(request_id)
+        # 1) Create a new output queue for the request.
+        if request_id in self.rid_to_queue:
+            raise ValueError(f"Request id {request_id} already running.")
+        self.rid_to_queue[request_id] = asyncio.Queue()

         # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
         detokenizer_req, engine_core_req = self.processor.process_inputs(
@@ -173,8 +169,10 @@ async def add_request(
         # 4) Add the EngineCoreRequest to EngineCore (separate process).
         await self.engine_core.add_request_async(engine_core_req)

-        # 5) Return the generator.
-        return stream.generator()
+        if self.log_requests:
+            logger.info("Added request %s.", request_id)
+
+        return self.rid_to_queue[request_id]

     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
@@ -194,7 +192,7 @@ async def generate(
         """
         Main function called by the API server to kick off a request
         * 1) Making an AsyncStream corresponding to the Request.
-        # 2) Processing the Input.
+        * 2) Processing the Input.
         * 3) Adding the Request to the Detokenizer.
         * 4) Adding the Request to the EngineCore (separate process).

@@ -206,94 +204,58 @@
         returning the RequestOutput back to the caller.
         """

-        # We start the output_handler on the first call to generate() so that
-        # we can call __init__ before the event loop starts, which enables us
-        # to handle startup failure gracefully in the OpenAI server.
-        if self.output_handler is None:
-            self.output_handler = asyncio.create_task(
-                self._run_output_handler())
-
-        async for output in await self.add_request(
+        try:
+            # We start the output_handler on the first call to generate() so
+            # we can call __init__ before the event loop, which enables us
+            # to handle startup failure gracefully in the OpenAI server.
+            if self.output_handler is None:
+                self.output_handler = asyncio.create_task(
+                    self._run_output_handler())
+
+            q = await self.add_request(
                 request_id,
                 prompt,
                 sampling_params,
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
-        ):
-            yield output
-
-    def _finish_stream(self, request_id: str):
-        stream = self.request_streams.pop(request_id, None)
-        if stream is not None:
-            stream.finish()
-
-    def _add_request_to_streams(
-        self,
-        request_id: str,
-    ) -> AsyncStream:
-
-        if request_id in self.request_streams:
-            raise ValueError(f"Request id {request_id} already running.")
-
-        # Avoid streams having circular ref to parent AsyncLLM object.
-        aborted_reqs = self.client_aborted_requests
-        stream = AsyncStream(request_id, aborted_reqs.append)
-        self.request_streams[request_id] = stream
-
-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
-
-        return stream
-
-    async def _process_cancellations(self) -> None:
-        """
-        Process requests cancelled from user disconnecting.
-
-        When a client disconnects, AsyncStream._cancel() is called.
-        We passed a callback to AsyncStream(), which appends to
-        self.client_aborted_requests.
-
-        As a result, if any requests are canceled from the user side
-        the request_id will show up in self.client_aborted_requests.
-        """
-
-        # Avoid streams having circular ref to parent AsyncLLM object.
-        if not self.client_aborted_requests:
-            return
-        reqs_to_abort = self.client_aborted_requests.copy()
-        self.client_aborted_requests.clear()
-
-        # Remove from Detokenizer.
-        self.detokenizer.abort_requests(reqs_to_abort)
-
-        # Remove from RequestStreams.
-        for request_id in reqs_to_abort:
-            if self.log_requests:
-                logger.info("User-cancelled request %s.", request_id)
-            self._finish_stream(request_id)
-
-        # Remove from EngineCore.
-        await self.engine_core.abort_requests_async(reqs_to_abort)
+            )
+
+            # The output_handler task pushes items into the queue.
+            # This task pulls from the queue and yields to caller.
+            while True:
+                # Note: drain queue without await if possible (avoids
+                # task switching under load which helps performance).
+                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+
+                # Note: both Detokenizer and EngineCore handle their
+                # own request cleanup based on finished.
+                if out.finished:
+                    del self.rid_to_queue[request_id]
+                    yield out
+                    break
+
+                yield out
+
+        # If the request is disconnected by the client, the
+        # generate() task will be canceled. So, we abort the
+        # request if we end up here.
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise

     def _process_request_outputs(self, request_outputs: List[RequestOutput]):
-        """Process outputs by putting them into per-request AsyncStreams."""
+        """Process outputs by putting them into per-request queues."""

         for request_output in request_outputs:
             request_id = request_output.request_id
-            assert request_id in self.request_streams
-
-            # Each request in the API server pulls from the per-request stream.
-            stream = self.request_streams.get(request_id)
-            if stream is not None:
-                stream.put(request_output)
-
-            # If finished, remove from the tracker.
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info("Finished request %s.", request_id)
-                self._finish_stream(request_id)
+
+            # Note: it is possible a request was aborted and removed from
+            # the state due to client cancellations, so if we encounter a
+            # request id not in the state, we skip.
+            if request_id in self.rid_to_queue:
+                self.rid_to_queue[request_id].put_nowait(request_output)

     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
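The get_nowait() fast path above is the task-switching reduction named in the PR title: await q.get() suspends and reschedules the coroutine even when an item is already sitting in the queue, while get_nowait() returns synchronously. Here is a small stand-alone sketch of just that consumer loop, with a generic queue and sentinel item rather than the real generate() method and RequestOutput objects.

import asyncio


async def consume(q: asyncio.Queue) -> None:
    """Drain a queue, awaiting only when it is actually empty."""
    while True:
        # Fast path: if items are already queued, pop synchronously and
        # avoid suspending this coroutine (no event-loop round trip).
        out = q.get_nowait() if q.qsize() > 0 else await q.get()
        print("got:", out)
        if out == "DONE":
            break


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for item in ("a", "b", "DONE"):
        q.put_nowait(item)
    await consume(q)


asyncio.run(main())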
@@ -306,24 +268,27 @@ async def _run_output_handler(self):
                 # 2) Detokenize based on the output.
                 request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

-                # 3) Put the RequestOutputs into the per-request AsyncStreams.
+                # 3) Put the RequestOutputs into the per-request queues.
                 self._process_request_outputs(request_outputs)

                 # 4) Abort any requests that finished due to stop strings.
                 await self.engine_core.abort_requests_async(reqs_to_abort)

-                # 5) Abort any requests due to client cancellations.
-                await self._process_cancellations()
-
         except BaseException as e:
             logger.error(e)
             raise e

     # TODO: can we eliminate these?

     async def abort(self, request_id: str) -> None:
-        # Note: Who Calls this? I dont think this is actually used.
-        raise ValueError("Not Supported on V1 yet.")
+        """Abort RequestId in self, detokenizer, and engine core."""
+
+        request_ids = [request_id]
+        await self.engine_core.abort_requests_async(request_ids)

Review comment: this can be fire and forget via
Reply: This is going to be updated in a follow up PR, so I will leave for now.

+        self.detokenizer.abort_requests(request_ids)
+
+        # If a request finishes while we await then the request_id
+        # will be removed from the tracked queues before we get here.
+        if request_id in self.rid_to_queue:
+            del self.rid_to_queue[request_id]

     def encode(
         self,
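For context on the review note above, a fire-and-forget abort would mean scheduling the engine-core abort without awaiting it on the cancellation path. The sketch below is only an illustration of that idea under stated assumptions: abort_requests_async here is a stand-in coroutine, and this is not the change made in this PR (the author defers it to a follow-up).

import asyncio


async def abort_requests_async(request_ids: list) -> None:
    # Stand-in for the real engine-core abort RPC (hypothetical here).
    await asyncio.sleep(0)
    print("aborted:", request_ids)


async def on_client_disconnect(request_id: str) -> None:
    # Fire and forget: schedule the abort but do not await it, so the
    # cancellation path returns immediately. (Keep a reference if the
    # task must not be garbage-collected before it runs.)
    task = asyncio.create_task(abort_requests_async([request_id]))
    _ = task


async def main() -> None:
    await on_client_disconnect("req-1")
    await asyncio.sleep(0.01)  # let the scheduled abort run


asyncio.run(main())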
This file was deleted.
@@ -32,7 +32,7 @@

 POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
-LOGGING_TIME_S = 5000
+LOGGING_TIME_S = POLLING_TIMEOUT_S

Review comment: Do we need to change this? A little bit unintuitive on why
Reply: Previously it was 5000s, this just switches it to be 5s.
Reply: yeah maybe just hard code this?
Reply: I think Simon's question is why we're tying them when they seem unrelated to each other.
Reply: Oh - yeah np will revert.


 class EngineCore:
Review comment: 🤯