♻️ Clean up mp engine integration
Signed-off-by: Joe Runde <[email protected]>
joerunde committed Jan 3, 2025
1 parent 70fc214 commit a8745c0
Showing 9 changed files with 140 additions and 36 deletions.
@@ -1,4 +1,5 @@
import json
import shutil

import openai # use the official client for correctness check
import pytest
@@ -63,16 +64,16 @@ def server_with_lora_modules_json(zephyr_lora_files):


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
async def client(server_with_lora_modules_json):
async with server_with_lora_modules_json.get_async_client(
) as async_client:
yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
async def test_static_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
@@ -87,23 +88,78 @@ async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(
client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):

response = await client_for_lora_lineage.post("load_lora_adapter",
cast_to=str,
body={
"lora_name":
"zephyr-lora-3",
"lora_path":
zephyr_lora_files
})
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):

response = await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "zephyr-lora-3",
"lora_path": zephyr_lora_files
})
# Ensure adapter loads before querying /models
assert "success" in response

models = await client_for_lora_lineage.models.list()
models = await client.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"


@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
with pytest.raises(openai.NotFoundError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "notfound",
"lora_path": "/not/an/adapter"
})


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
tmp_path):
invalid_files = tmp_path / "invalid_files"
invalid_files.mkdir()
(invalid_files / "adapter_config.json").write_text("this is not json")

with pytest.raises(openai.BadRequestError):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_files)
})


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
tmp_path, zephyr_lora_files):
invalid_rank = tmp_path / "invalid_rank"

# Copy adapter from zephyr_lora_files to invalid_rank
shutil.copytree(zephyr_lora_files, invalid_rank)

with open(invalid_rank / "adapter_config.json") as f:
adapter_config = json.load(f)

# Change rank to invalid value
adapter_config["r"] = 1024
with open(invalid_rank / "adapter_config.json", "w") as f:
json.dump(adapter_config, f)

with pytest.raises(openai.BadRequestError,
match="is greater than max_lora_rank"):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_rank)
})
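
For context, a minimal sketch of how a caller could exercise the same dynamic-adapter route against a running server, mirroring the success and error paths covered by the new tests above; the base URL, API key, adapter name, and path are placeholders, not values from this commit.

import asyncio

import openai


async def load_adapter_example():
    # Placeholder server address and dummy key; adjust for your deployment.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        response = await client.post("load_lora_adapter",
                                     cast_to=str,
                                     body={
                                         "lora_name": "my-adapter",
                                         "lora_path": "/path/to/adapter"
                                     })
        # A successful load returns a confirmation string containing "success",
        # as asserted in the lineage test above.
        print(response)
    except openai.NotFoundError:
        print("No adapter found at the given path (HTTP 404)")
    except openai.BadRequestError as err:
        print(f"Adapter rejected as invalid (HTTP 400): {err}")


asyncio.run(load_adapter_example())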
8 changes: 5 additions & 3 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -51,7 +51,7 @@ async def _async_serving_chat_init():
engine = MockEngine()
model_config = await engine.get_model_config()

models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
serving_completion = OpenAIServingChat(engine,
model_config,
models,
@@ -72,7 +72,8 @@ def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
@@ -115,7 +116,8 @@ def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
9 changes: 6 additions & 3 deletions tests/entrypoints/openai/test_serving_models.py
@@ -4,6 +4,7 @@
import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
@@ -21,10 +22,12 @@

async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config = MagicMock(spec=ModelConfig)
mock_engine_client = MagicMock(spec=EngineClient)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048

serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)
@@ -113,5 +116,5 @@ async def test_unload_lora_adapter_not_found():
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
assert response.type == "NotFoundError"
assert response.code == HTTPStatus.NOT_FOUND
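
Pulling the scattered keyword changes together, a sketch of how OpenAIServingModels is now constructed with the required engine_client handle; the model name is illustrative, and BaseModelPath is assumed to be exported from the same serving_models module.

from unittest.mock import MagicMock

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)

# Stand-ins for a real engine client and model config, as in the unit tests.
mock_engine_client = MagicMock(spec=EngineClient)
mock_model_config = MagicMock(spec=ModelConfig)
mock_model_config.max_model_len = 2048

base_model_paths = [BaseModelPath(name="my-model", model_path="my-model")]

serving_models = OpenAIServingModels(engine_client=mock_engine_client,
                                     base_model_paths=base_model_paths,
                                     model_config=mock_model_config,
                                     lora_modules=None,
                                     prompt_adapters=None)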
8 changes: 7 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
@@ -128,10 +128,16 @@ class RPCLoadAdapterRequest:
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))


@dataclass
class RPCAdapterLoadedResponse:
request_id: str


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError]


def ENGINE_DEAD_ERROR(
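
A small sketch of the new message pairing; only request_id and RPCAdapterLoadedResponse appear in this hunk, so the lora_request field on RPCLoadAdapterRequest is assumed from surrounding context.

from vllm.engine.multiprocessing import (RPCAdapterLoadedResponse,
                                         RPCLoadAdapterRequest)
from vllm.lora.request import LoRARequest

# Illustrative adapter; the lora_request field name is an assumption.
lora_request = LoRARequest(lora_name="my-adapter",
                           lora_int_id=1,
                           lora_path="/path/to/adapter")
request = RPCLoadAdapterRequest(lora_request=lora_request)

# On success the engine now replies with a dedicated response type that echoes
# the auto-generated request_id, instead of echoing the request itself.
response = RPCAdapterLoadedResponse(request_id=request.request_id)
assert response.request_id == request.request_id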
21 changes: 14 additions & 7 deletions vllm/engine/multiprocessing/client.py
@@ -25,7 +25,8 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCLoadAdapterRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest)
@@ -242,16 +243,22 @@ async def run_output_handler_loop(self):
if queue is not None:
queue.put_nowait(exception)
else:
# Put each output into the appropriate steam.
for request_output in request_outputs:
queue = self.output_queues.get(
request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
# Put each output into the appropriate queue.
if isinstance(request_outputs, RPCAdapterLoadedResponse):
self._add_output(request_outputs)
else:
for request_output in request_outputs:
self._add_output(request_output)

except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")

def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse]):
queue = self.output_queues.get(request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)

async def setup(self):
"""Setup the client before it starts sending server requests."""

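
To make the routing concrete, a self-contained toy (not the real MQLLMEngineClient) showing how a per-request_id queue lets the output handler hand either generation outputs or the new adapter-loaded acknowledgement back to the waiting coroutine.

import asyncio
import uuid
from dataclasses import dataclass


@dataclass
class AdapterLoaded:
    request_id: str


class ToyClient:

    def __init__(self):
        self.output_queues: dict[str, asyncio.Queue] = {}

    def _add_output(self, output):
        # Route a reply to whichever caller registered this request_id.
        queue = self.output_queues.get(output.request_id)
        if queue is not None:
            queue.put_nowait(output)

    async def load_adapter(self) -> AdapterLoaded:
        request_id = str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue()
        self.output_queues[request_id] = queue
        # The real client sends the request over ZeroMQ; here the "engine"
        # acknowledges immediately.
        self._add_output(AdapterLoaded(request_id=request_id))
        try:
            return await queue.get()
        finally:
            self.output_queues.pop(request_id, None)


print(asyncio.run(ToyClient().load_adapter()))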
14 changes: 10 additions & 4 deletions vllm/engine/multiprocessing/engine.py
@@ -14,7 +14,8 @@
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCLoadAdapterRequest,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest)
@@ -296,8 +297,9 @@ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
# Otherwise, echo back the request if successful
self._send_outputs([request])
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))

def _health_check(self):
# Send unhealthy if engine has already errored
@@ -311,7 +313,11 @@ def _health_check(self):
self._send_unhealthy(e)

def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
"""Send outputs back to the engine client. These can be:
- Exceptions
- A list of generation outputs
- A response from loading a lora adapter
"""
if outputs:
try:
from ray.exceptions import RayTaskError
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/run_batch.py
@@ -215,6 +215,7 @@ async def main(args):

# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
21 changes: 18 additions & 3 deletions vllm/entrypoints/openai/serving_models.py
@@ -144,11 +144,26 @@ async def load_lora_adapter(
lora_path=lora_path)

# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except ValueError as e:
# Adapter not found or lora configuration errors
if "No adapter found" in str(e):
return create_error_response(message=str(e),
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
else:
return create_error_response(
message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
except BaseException as e:
# Some other unexpected problem loading the adapter, e.g. malformed
# input files.
# More detailed error messages for the user would be nicer here
return create_error_response(message=str(e),
err_type="InvalidUserInput",
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)

self.lora_requests.append(lora_request)
@@ -207,8 +222,8 @@ async def _check_unload_lora_adapter_request(
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)

return None

8 changes: 8 additions & 0 deletions vllm/lora/worker_manager.py
@@ -115,6 +115,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
embedding_padding_modules=self.embedding_padding_modules,
weights_mapper=hf_to_vllm_mapper)

except FileNotFoundError as e:
# FileNotFoundError should be raised if both
# - No adapter found to download from huggingface (or in
# offline mode)
# - No local adapter files found at `lora_request.lora_path`
raise ValueError(
f"Loading lora {lora_request.lora_name} failed: No adapter "
f"found for {lora_path}") from e
except Exception as e:
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
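
Taken together with the serving_models.py change above, a runnable toy of the error-translation chain this commit sets up: the worker wraps a FileNotFoundError in a ValueError whose message contains "No adapter found", and the serving layer keys its 404-vs-400 decision off that text. Names here are illustrative, not taken from the vLLM source.

from http import HTTPStatus


def load_adapter_stub(lora_path: str):
    try:
        raise FileNotFoundError(lora_path)  # stand-in for a failed lookup
    except FileNotFoundError as e:
        raise ValueError(
            f"Loading lora my-adapter failed: No adapter found for {lora_path}"
        ) from e


try:
    load_adapter_stub("/not/an/adapter")
except ValueError as e:
    status = (HTTPStatus.NOT_FOUND
              if "No adapter found" in str(e) else HTTPStatus.BAD_REQUEST)
    assert status == HTTPStatus.NOT_FOUND
    assert isinstance(e.__cause__, FileNotFoundError)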
