[Bugfix] Validate lora adapters to avoid crashing server #11727

Merged 14 commits on Jan 10, 2025
Changes from 3 commits
@@ -1,4 +1,5 @@
import json
import shutil

import openai # use the official client for correctness check
import pytest
@@ -63,16 +64,16 @@ def server_with_lora_modules_json(zephyr_lora_files):


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
async def client(server_with_lora_modules_json):
async with server_with_lora_modules_json.get_async_client(
) as async_client:
yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
async def test_static_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
@@ -87,23 +88,78 @@ async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(
client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):

response = await client_for_lora_lineage.post("load_lora_adapter",
cast_to=str,
body={
"lora_name":
"zephyr-lora-3",
"lora_path":
zephyr_lora_files
})
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):

response = await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "zephyr-lora-3",
"lora_path": zephyr_lora_files
})
# Ensure adapter loads before querying /models
assert "success" in response

models = await client_for_lora_lineage.models.list()
models = await client.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"


@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
with pytest.raises(openai.NotFoundError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "notfound",
"lora_path": "/not/an/adapter"
})


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
tmp_path):
invalid_files = tmp_path / "invalid_files"
invalid_files.mkdir()
(invalid_files / "adapter_config.json").write_text("this is not json")

with pytest.raises(openai.BadRequestError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_files)
})


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
tmp_path, zephyr_lora_files):
invalid_rank = tmp_path / "invalid_rank"

# Copy adapter from zephyr_lora_files to invalid_rank
shutil.copytree(zephyr_lora_files, invalid_rank)

with open(invalid_rank / "adapter_config.json") as f:
adapter_config = json.load(f)

# Change rank to invalid value
adapter_config["r"] = 1024
with open(invalid_rank / "adapter_config.json", "w") as f:
json.dump(adapter_config, f)

with pytest.raises(openai.BadRequestError,
match="is greater than max_lora_rank"):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_rank)
})
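
For context, these tests pin down the new behavior: a bad dynamic-load request now comes back as an HTTP error instead of crashing the server. Below is a minimal client-side sketch of the same flow, assuming a locally running vLLM OpenAI server with LoRA support and runtime adapter updating enabled; the base_url, adapter name, and path are placeholders.

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def load_adapter(name: str, path: str) -> None:
    try:
        # Same low-level route the tests above exercise via AsyncOpenAI.post()
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={"lora_name": name, "lora_path": path})
        print(f"loaded {name}")
    except openai.NotFoundError:
        # The path does not exist or does not contain a LoRA adapter
        print(f"adapter not found at {path}")
    except openai.BadRequestError as exc:
        # e.g. unparsable adapter_config.json or a rank above max_lora_rank
        print(f"adapter rejected: {exc}")

asyncio.run(load_adapter("my-lora", "/path/to/adapter"))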
8 changes: 5 additions & 3 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -51,7 +51,7 @@ async def _async_serving_chat_init():
engine = MockEngine()
model_config = await engine.get_model_config()

models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
serving_completion = OpenAIServingChat(engine,
model_config,
models,
@@ -72,7 +72,8 @@ def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
@@ -115,7 +116,8 @@ def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
9 changes: 6 additions & 3 deletions tests/entrypoints/openai/test_serving_models.py
@@ -4,6 +4,7 @@
import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
@@ -21,10 +22,12 @@

async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config = MagicMock(spec=ModelConfig)
mock_engine_client = MagicMock(spec=EngineClient)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048

serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)
@@ -113,5 +116,5 @@ async def test_unload_lora_adapter_not_found():
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
assert response.type == "NotFoundError"
assert response.code == HTTPStatus.NOT_FOUND
4 changes: 4 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -1257,6 +1257,10 @@ async def stop_profile(self) -> None:
else:
self.engine.model_executor._run_workers("stop_profile")

async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
self.engine.add_lora(lora_request)


# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:
20 changes: 17 additions & 3 deletions vllm/engine/multiprocessing/__init__.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union, overload

@@ -120,10 +121,23 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2


@dataclass
class RPCLoadAdapterRequest:
lora_request: LoRARequest
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class RPCAdapterLoadedResponse:
request_id: str


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest]
RPCUProfileRequest, RPCLoadAdapterRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError]


def ENGINE_DEAD_ERROR(
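
One detail worth noting in RPCLoadAdapterRequest above: field(default_factory=...) runs the factory once per instance, so every request constructed without an explicit request_id gets its own UUID, whereas a plain request_id: str = str(uuid.uuid4()) default would be evaluated once and shared. A tiny standalone check of that behavior (illustrative only, not vLLM code):

import uuid
from dataclasses import dataclass, field

@dataclass
class DemoRequest:
    # The lambda is called for each new instance
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))

a, b = DemoRequest(), DemoRequest()
assert a.request_id != b.request_id  # each request gets a distinct id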
45 changes: 37 additions & 8 deletions vllm/engine/multiprocessing/client.py
@@ -25,8 +25,10 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
@@ -241,16 +243,22 @@ async def run_output_handler_loop(self):
if queue is not None:
queue.put_nowait(exception)
else:
# Put each output into the appropriate steam.
for request_output in request_outputs:
queue = self.output_queues.get(
request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
# Put each output into the appropriate queue.
if isinstance(request_outputs, RPCAdapterLoadedResponse):
self._add_output(request_outputs)
else:
for request_output in request_outputs:
self._add_output(request_output)

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")

def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse]):
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)

async def setup(self):
"""Setup the client before it starts sending server requests."""

@@ -659,3 +667,24 @@ async def stop_profile(self) -> None:

await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)

# Create an output queue for this request.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
self.output_queues[request.request_id] = queue

# Send the request
request_bytes = pickle.dumps(request)
await self.input_socket.send_multipart((request_bytes, ), copy=False)

# Wait for the response
request_output = await queue.get()
self.output_queues.pop(request.request_id)

# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
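
The add_lora round-trip above reuses the generate-path plumbing: register a response queue under a fresh request id, send the pickled request over the input socket, and wait for either an RPCAdapterLoadedResponse or an exception routed back by the output handler. A stripped-down illustration of that queue-per-request pattern (toy code with a fake reply task, not the actual vLLM classes):

import asyncio
import uuid

response_queues: dict[str, asyncio.Queue] = {}

async def fake_engine_reply(request_id: str) -> None:
    # Stand-in for the engine's response arriving on the output socket
    await asyncio.sleep(0.01)
    response_queues[request_id].put_nowait("adapter loaded")

async def send_and_wait() -> str:
    request_id = str(uuid.uuid4())
    response_queues[request_id] = asyncio.Queue()
    asyncio.create_task(fake_engine_reply(request_id))
    try:
        result = await response_queues[request_id].get()
    finally:
        response_queues.pop(request_id)
    if isinstance(result, BaseException):
        raise result  # errors are re-raised in the caller, as add_lora does above
    return result

print(asyncio.run(send_and_wait()))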
27 changes: 24 additions & 3 deletions vllm/engine/multiprocessing/engine.py
@@ -14,8 +14,10 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
@@ -234,6 +236,8 @@ def handle_new_input(self):
self.start_profile()
else:
self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
@@ -284,6 +288,19 @@ def _handle_abort_request(self, request: RPCAbortRequest):
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)

def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
try:
self.engine.add_lora(request.lora_request)
except BaseException as e:
# Send back an error if the adapter fails to load
rpc_err = RPCError(request_id=request.request_id,
                   is_engine_errored=False,
                   exception=e)
self._send_outputs(rpc_err)
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))

def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
@@ -296,7 +313,11 @@ def _health_check(self):
self._send_unhealthy(e)

def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs:
try:
from ray.exceptions import RayTaskError
5 changes: 5 additions & 0 deletions vllm/engine/protocol.py
@@ -270,3 +270,8 @@ async def start_profile(self) -> None:
async def stop_profile(self) -> None:
"""Start profiling the engine"""
...

@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
...
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -631,12 +631,12 @@ def init_app_state(
logger.info("Using supplied chat template:\n%s", resolved_chat_template)

state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
# TODO: The chat template is now broken for lora adapters :(
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/run_batch.py
@@ -215,6 +215,7 @@ async def main(args):

# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,