
[Misc] Minimum requirements for SageMaker compatibility #11576

Merged: 12 commits, Jan 2, 2025
10 changes: 9 additions & 1 deletion Dockerfile
@@ -235,7 +235,9 @@ RUN mv vllm test_docs/

#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai

# define the sagemaker stage first, so it is not the default target for `docker build`
FROM vllm-base AS vllm-sagemaker

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
@@ -247,5 +249,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \

ENV VLLM_USAGE_SOURCE production-docker-image

# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server", "--port", "8080"]

# the default image is `vllm-openai`, which is identical to the sagemaker stage but without forcing the port
FROM vllm-sagemaker AS vllm-openai

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
#################### OPENAI API SERVER ####################
55 changes: 54 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,7 @@
from typing import AsyncIterator, Optional, Set, Tuple

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -44,11 +44,15 @@
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
@@ -315,6 +319,12 @@ async def health(raw_request: Request) -> Response:
return Response(status_code=200)


@router.post("/ping")
async def ping(raw_request: Request) -> Response:
"""Ping check. Endpoint required for SageMaker"""
return await health(raw_request)
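
Not part of the diff: a minimal, self-contained sketch of the delegation pattern used above, in which /ping simply reuses the existing health check so SageMaker's mandatory ping endpoint returns the same status. It assumes only that fastapi and httpx (needed by TestClient) are installed; the app and routes here are illustrative stand-ins, not the server in this file.

from fastapi import FastAPI, Response
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/health")
async def health() -> Response:
    # Report readiness with a bare 200.
    return Response(status_code=200)

@app.post("/ping")
async def ping() -> Response:
    # SageMaker only needs a successful status code, so delegate to /health.
    return await health()

client = TestClient(app)
assert client.post("/ping").status_code == 200
# end of illustrative sketch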


@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
@@ -488,6 +498,48 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)


TASK_HANDLERS = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
},
# should `embed` be a pooling task by default?
"embed": {
"messages": (EmbeddingChatRequest, create_embedding),
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"messages": (PoolingChatRequest, create_score),
"default": (PoolingCompletionRequest, create_score),
},
}


@router.post("/invocations")
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
"""
body = await raw_request.json()
task = raw_request.app.state.task

if task not in TASK_HANDLERS:
raise HTTPException(
status_code=400,
detail=f"Unsupported task: '{task}' for '/invocations'. "
f"Expected one of {set(TASK_HANDLERS.keys())}")

handler_config = TASK_HANDLERS[task]
if "messages" in body:
request_model, handler = handler_config["messages"]
else:
request_model, handler = handler_config["default"]

# this is required since we lose the FastAPI automatic casting
request = request_model.model_validate(body)
return await handler(request, raw_request)
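
Not part of the diff: a hedged client-side sketch of how the payload shape drives the dispatch above. It assumes a server built from the vllm-sagemaker stage is serving a `generate`-task model on localhost:8080, that the requests library is installed, and that the model name and prompts are placeholders.

import requests

url = "http://localhost:8080/invocations"

# Routed to create_chat_completion (ChatCompletionRequest) because "messages" is present.
chat_payload = {
    "model": "my-model",
    "messages": [{"role": "user", "content": "Hello!"}],
}

# Routed to create_completion (CompletionRequest): no "messages" key, so the default handler is used.
completion_payload = {
    "model": "my-model",
    "prompt": "Hello!",
}

for payload in (chat_payload, completion_payload):
    resp = requests.post(url, json=payload, timeout=60)
    print(resp.status_code, resp.json())
# end of illustrative sketch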


if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
@@ -694,6 +746,7 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.task = model_config.task


def create_server_socket(addr: Tuple[str, int]) -> socket.socket: