Skip to content

Commit

Permalink
[Misc][LoRA] Ensure Lora Adapter requests return adapter name (#11094)
Browse files Browse the repository at this point in the history
Signed-off-by: Jiaxin Shan <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: Akshat Tripathi <[email protected]>
  • Loading branch information
2 people authored and Akshat-Tripathi committed Dec 12, 2024
1 parent 194e7c5 commit 59d5e84
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
11 changes: 11 additions & 0 deletions tests/entrypoints/openai/test_serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.lora.request import LoRARequest

MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
Expand All @@ -33,6 +34,16 @@ async def _async_serving_engine_init():
return serving_engine


@pytest.mark.asyncio
async def test_serving_model_name():
serving_engine = await _async_serving_engine_init()
assert serving_engine._get_model_name(None) == MODEL_NAME
request = LoRARequest(lora_name="adapter",
lora_path="/path/to/adapter2",
lora_int_id=1)
assert serving_engine._get_model_name(request) == request.lora_name


@pytest.mark.asyncio
async def test_load_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
Expand Down
14 changes: 8 additions & 6 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ async def create_chat_completion(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

model_name = self._get_model_name(lora_request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)

tool_parser = self.tool_parser
Expand Down Expand Up @@ -238,13 +240,13 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
request, result_generator, request_id, model_name,
conversation, tokenizer, request_metadata)

try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
request, result_generator, request_id, model_name,
conversation, tokenizer, request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
Expand All @@ -259,11 +261,11 @@ async def chat_completion_stream_generator(
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
model_name: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
Expand Down Expand Up @@ -604,12 +606,12 @@ async def chat_completion_full_generator(
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
model_name: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.base_model_paths[0].name
created_time = int(time.time())
final_res: Optional[RequestOutput] = None

Expand Down
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ async def create_completion(
return self.create_error_response(
"suffix is not currently supported")

model_name = self.base_model_paths[0].name
request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time())

Expand Down Expand Up @@ -162,6 +161,7 @@ async def create_completion(
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)

model_name = self._get_model_name(lora_request)
num_prompts = len(engine_prompts)

# Similar to the OpenAI API, when n != best_of, we do not stream the
Expand Down
13 changes: 13 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,16 @@ async def unload_lora_adapter(

def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)

def _get_model_name(self, lora: Optional[LoRARequest]):
"""
Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora is not None:
return lora.lora_name
return self.base_model_paths[0].name

0 comments on commit 59d5e84

Please sign in to comment.