[V1][3/N] API Server: Reduce Task Switching + Handle Abort Properly #11534

Merged (12 commits) on Dec 27, 2024
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -61,6 +61,8 @@ class _HfExamplesInfo:
"DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
"DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501
trust_remote_code=True),
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
159 changes: 62 additions & 97 deletions vllm/v1/engine/async_llm.py
@@ -9,14 +9,13 @@
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_stream import AsyncStream
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
@@ -54,10 +53,8 @@ def __init__(
lora_config=vllm_config.lora_config)
self.tokenizer.ping()

# Request streams (map of request_id -> AsyncStream).
self.request_streams: Dict[str, AsyncStream] = {}
# List of cancelled request ids to be aborted.
self.client_aborted_requests: List[str] = []
# Request streams (map of request_id -> queue).
self.rid_to_queue: Dict[str, asyncio.Queue] = {}

# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
@@ -153,14 +150,13 @@ async def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
) -> asyncio.Queue[RequestOutput]:
"""Add new request to the AsyncLLM."""

if self.detokenizer.is_request_active(request_id):
raise ValueError(f"Request {request_id} already exists.")

# 1) Create a new AsyncStream for the request.
stream = self._add_request_to_streams(request_id)
# 1) Create a new output queue for the request.
if request_id in self.rid_to_queue:
raise ValueError(f"Request id {request_id} already running.")
self.rid_to_queue[request_id] = asyncio.Queue()

# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
@@ -173,8 +169,10 @@
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req)

# 5) Return the generator.
return stream.generator()
if self.log_requests:
logger.info("Added request %s.", request_id)

return self.rid_to_queue[request_id]

# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
@@ -194,7 +192,7 @@ async def generate(
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
# 2) Processing the Input.
* 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).

@@ -206,94 +204,58 @@
returning the RequestOutput back to the caller.
"""

# We start the output_handler on the first call to generate() so that
# we can call __init__ before the event loop starts, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())

async for output in await self.add_request(
try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())

q = await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield output

def _finish_stream(self, request_id: str):
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish()

def _add_request_to_streams(
self,
request_id: str,
) -> AsyncStream:

if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")

# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs = self.client_aborted_requests
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream

if self.log_requests:
logger.info("Added request %s.", request_id)
)

return stream

async def _process_cancellations(self) -> None:
"""
Process requests cancelled from user disconnecting.

When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
self.client_aborted_requests.

As a result, if any requests are canceled from the user side
the request_id will show up in self.client_aborted_requests.
"""

# Avoid streams having circular ref to parent AsyncLLM object.
if not self.client_aborted_requests:
return
reqs_to_abort = self.client_aborted_requests.copy()
self.client_aborted_requests.clear()

# Remove from Detokenizer.
self.detokenizer.abort_requests(reqs_to_abort)

# Remove from RequestStreams.
for request_id in reqs_to_abort:
if self.log_requests:
logger.info("User-cancelled request %s.", request_id)
self._finish_stream(request_id)

# Remove from EngineCore.
await self.engine_core.abort_requests_async(reqs_to_abort)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
while True:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() if q.qsize() > 0 else await q.get()
Comment on lines +228 to +230

Collaborator: 🤯
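
For reference, a minimal standalone sketch of the drain pattern this comment reacts to; the names are hypothetical and it is not code from the PR. Buffered items are pulled synchronously with get_nowait(), and the task only awaits (and can be switched out) once the queue is empty.

import asyncio
from typing import Any

async def drain_outputs(q: asyncio.Queue[Any]) -> None:
    """Consume a per-request output queue until a finished item arrives."""
    while True:
        # Fast path: an item is already buffered, so take it without awaiting.
        # Slow path: the queue is empty, so suspend until the producer
        # (the output handler task) puts the next item.
        out = q.get_nowait() if q.qsize() > 0 else await q.get()
        if getattr(out, "finished", False):  # hypothetical finished flag
            break

The hot path avoids the per-item overhead of awaiting q.get() whenever output has already been buffered by the handler task.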


# Note: both Detokenizer and EngineCore handle their
# own request cleanup based on finished.
if out.finished:
del self.rid_to_queue[request_id]
yield out
break

yield out

# If the request is disconnected by the client, the
# generate() task will be canceled. So, we abort the
# request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
raise
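
As a minimal illustration of that cancellation path (hypothetical driver code, not from the PR; the generate() argument order is assumed): cancelling the task that iterates generate() raises CancelledError inside it, typically while it awaits the queue, which triggers the abort() call above.

import asyncio
import contextlib

async def handle_request(engine, prompt, sampling_params, request_id: str) -> None:
    # Hypothetical API-server handler that streams outputs back to a client.
    async for output in engine.generate(prompt, sampling_params, request_id):
        ...

async def main(engine, sampling_params) -> None:
    task = asyncio.create_task(
        handle_request(engine, "Hello", sampling_params, "req-0"))
    await asyncio.sleep(0.1)
    # Simulate a client disconnect: cancel the consumer task. The
    # CancelledError propagates into generate(), which aborts the request.
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task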

def _process_request_outputs(self, request_outputs: List[RequestOutput]):
"""Process outputs by putting them into per-request AsyncStreams."""
"""Process outputs by putting them into per-request queues."""

for request_output in request_outputs:
request_id = request_output.request_id
assert request_id in self.request_streams

# Each request in the API server pulls from the per-request stream.
stream = self.request_streams.get(request_id)
if stream is not None:
stream.put(request_output)

# If finished, remove from the tracker.
if request_output.finished:
if self.log_requests:
logger.info("Finished request %s.", request_id)
self._finish_stream(request_id)
# Note: it is possible a request was aborted and removed from
# the state due to client cancellations, so if we encounter a
# request id not in the state, we skip.
if request_id in self.rid_to_queue:
self.rid_to_queue[request_id].put_nowait(request_output)

async def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -306,24 +268,27 @@ async def _run_output_handler(self):
# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

# 3) Put the RequestOutputs into the per-request AsyncStreams.
# 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs)

# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)

# 5) Abort any requests due to client cancellations.
await self._process_cancellations()

except BaseException as e:
logger.error(e)
raise e

# TODO: can we eliminate these?

async def abort(self, request_id: str) -> None:
# Note: Who Calls this? I dont think this is actually used.
raise ValueError("Not Supported on V1 yet.")
"""Abort RequestId in self, detokenizer, and engine core."""

request_ids = [request_id]
await self.engine_core.abort_requests_async(request_ids)
Collaborator: This can be fire-and-forget via asyncio.create_task, probably, but it's a tiny optimization because the RPC call is already async.

Collaborator (author): This is going to be updated in a follow-up PR, so I will leave it for now.

self.detokenizer.abort_requests(request_ids)

# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]
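
A minimal sketch of the fire-and-forget variant discussed in the thread above, assuming abort_requests_async stays a coroutine; this is not part of the PR.

import asyncio
from typing import List, Set

_background_tasks: Set[asyncio.Task] = set()

def abort_in_background(engine_core, request_ids: List[str]) -> None:
    # Schedule the RPC without awaiting it. Keep a strong reference so the
    # task is not garbage-collected before it completes.
    task = asyncio.create_task(engine_core.abort_requests_async(request_ids))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

As the thread notes, the gain would be small since the RPC call is already async, and this PR keeps the awaited call.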

def encode(
self,
55 changes: 0 additions & 55 deletions vllm/v1/engine/async_stream.py

This file was deleted.

2 changes: 1 addition & 1 deletion vllm/v1/engine/core.py
@@ -32,7 +32,7 @@

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5000
LOGGING_TIME_S = POLLING_TIMEOUT_S
Collaborator: Do we need to change this? It's a little unintuitive why.

Collaborator (author): Previously it was 5000 s; this just switches it to 5 s.

Collaborator: Yeah, maybe just hard-code this?

Member (@ywang96, Dec 27, 2024): I think Simon's question is why we're tying these together when they seem unrelated to each other.

Collaborator (author): Oh, yeah, np, will revert.
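
For clarity, a sketch of the hard-coded alternative suggested in this thread; the 5-second value is an assumption based on the discussion, and the author agreed to revert the coupling shown in the diff above.

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000

# Keep the logging interval as its own constant instead of tying it to the
# polling timeout; per the discussion, the original 5000 was meant as 5 seconds.
LOGGING_TIME_S = 5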



class EngineCore: