Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Minimum requirements for SageMaker compatibility #11576

Merged
merged 12 commits into from
Jan 2, 2025
Merged
13 changes: 11 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ RUN mv vllm test_docs/
#################### TEST IMAGE ####################

#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai
# base openai image with additional requirements, for any subsequent openai-style images
FROM vllm-base AS vllm-openai-base

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
Expand All @@ -247,5 +247,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \

ENV VLLM_USAGE_SOURCE production-docker-image

# define sagemaker first, so it is not default from `docker build`
FROM vllm-openai-base AS vllm-sagemaker

COPY ./sagemaker-entrypoint.sh .
RUN chmod +x sagemaker-entrypoint.sh
ENTRYPOINT ["./sagemaker-entrypoint.sh"]

from vllm-openai-base as vllm-openai
simon-mo marked this conversation as resolved.
Show resolved Hide resolved

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
nathan-az marked this conversation as resolved.
Show resolved Hide resolved
#################### OPENAI API SERVER ####################
24 changes: 24 additions & 0 deletions sagemaker-entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash

# Define the prefix for environment variables to look for
PREFIX="SM_VLLM_"
ARG_PREFIX="--"

# Initialize an array for storing the arguments
# port 8080 required by sagemaker, https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
ARGS=(--port 8080)

# Loop through all environment variables
while IFS='=' read -r key value; do
# Remove the prefix from the key, convert to lowercase, and replace underscores with dashes
arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')

# Add the argument name and value to the ARGS array
ARGS+=("${ARG_PREFIX}${arg_name}")
if [ -n "$value" ]; then
ARGS+=("$value")
fi
done < <(env | grep "^${PREFIX}")

# Pass the collected arguments to the main entrypoint
exec python3 -m vllm.entrypoints.openai.api_server "${ARGS[@]}"
61 changes: 60 additions & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import AsyncIterator, Optional, Set, Tuple

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
Expand Down Expand Up @@ -44,11 +44,15 @@
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
Expand Down Expand Up @@ -315,6 +319,12 @@ async def health(raw_request: Request) -> Response:
return Response(status_code=200)


@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
"""Ping check. Endpoint required for SageMaker"""
return await health(raw_request)


@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
Expand Down Expand Up @@ -488,6 +498,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)


TASK_HANDLERS = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
},
"embed": {
nathan-az marked this conversation as resolved.
Show resolved Hide resolved
"messages": (EmbeddingChatRequest, create_embedding),
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"default": (ScoreRequest, create_score),
},
"reward": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
"classify": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
}


@router.post("/invocations")
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
"""
body = await raw_request.json()
task = raw_request.app.state.task

if task not in TASK_HANDLERS:
raise HTTPException(
status_code=400,
detail=f"Unsupported task: '{task}' for '/invocations'. "
f"Expected one of {set(TASK_HANDLERS.keys())}")

handler_config = TASK_HANDLERS[task]
if "messages" in body:
request_model, handler = handler_config["messages"]
else:
request_model, handler = handler_config["default"]

# this is required since we lose the FastAPI automatic casting
request = request_model.model_validate(body)
return await handler(request, raw_request)


if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
Expand Down Expand Up @@ -694,6 +752,7 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.task = model_config.task


def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
Expand Down
Loading