[V1] Refactor get_executor_cls (#11754)
ruisearch42 authored Jan 6, 2025
1 parent f8fcca1, commit 022c5c6
Showing 5 changed files with 26 additions and 46 deletions.
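
In short, this commit removes the duplicated _get_executor_cls classmethods from AsyncLLM and LLMEngine and consolidates executor selection into a single Executor.get_class static method on the abstract Executor base. A minimal sketch of the call-site migration, using only names that appear in the diffs below (vllm_config stands for any VllmConfig built as in the tests):

    # Before: each frontend carried its own private copy of the resolver.
    #     executor_class = AsyncLLM._get_executor_cls(vllm_config)
    #     executor_class = LLMEngine._get_executor_cls(vllm_config)
    # After: one shared resolver on the abstract base class.
    executor_class = Executor.get_class(vllm_config)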
tests/v1/engine/test_engine_core.py (6 changes: 3 additions & 3 deletions)
@@ -8,8 +8,8 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core import EngineCore
+from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch):
     """Setup the EngineCore."""
     engine_args = EngineArgs(model=MODEL_NAME)
     vllm_config = engine_args.create_engine_config()
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
 
     engine_core = EngineCore(vllm_config=vllm_config,
                              executor_class=executor_class)
@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
     """Setup the EngineCore."""
     engine_args = EngineArgs(model=MODEL_NAME)
     vllm_config = engine_args.create_engine_config()
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
 
     engine_core = EngineCore(vllm_config=vllm_config,
                              executor_class=executor_class)
tests/v1/engine/test_engine_core_client.py (6 changes: 3 additions & 3 deletions)
@@ -11,8 +11,8 @@
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
     vllm_config = engine_args.create_engine_config(
         UsageContext.UNKNOWN_CONTEXT)
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
     client = EngineCoreClient.make_client(
         multiprocess_mode=multiprocessing_mode,
         asyncio_mode=False,
@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
     engine_args = EngineArgs(model=MODEL_NAME)
     vllm_config = engine_args.create_engine_config(
         usage_context=UsageContext.UNKNOWN_CONTEXT)
-    executor_class = AsyncLLM._get_executor_cls(vllm_config)
+    executor_class = Executor.get_class(vllm_config)
     client = EngineCoreClient.make_client(
         multiprocess_mode=True,
         asyncio_mode=True,
vllm/v1/engine/async_llm.py (21 changes: 1 addition & 20 deletions)
@@ -22,7 +22,6 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.executor.ray_utils import initialize_ray_cluster
 
 logger = init_logger(__name__)
 
@@ -105,7 +104,7 @@ def from_engine_args(
         else:
             vllm_config = engine_config
 
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         # Create the AsyncLLM.
         return cls(
@@ -127,24 +126,6 @@ def shutdown(self):
         if handler := getattr(self, "output_handler", None):
             handler.cancel()
 
-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            initialize_ray_cluster(vllm_config.parallel_config)
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-        return executor_class
-
     async def add_request(
         self,
         request_id: str,
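
Note: the AsyncLLM copy of _get_executor_cls removed above also called initialize_ray_cluster(vllm_config.parallel_config) before choosing RayExecutor, while the new shared Executor.get_class does not. This diff does not show where that call moves, so Ray cluster initialization is presumably handled by RayExecutor itself.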
vllm/v1/engine/llm_engine.py (20 changes: 1 addition & 19 deletions)
@@ -89,7 +89,7 @@ def from_engine_args(
 
         # Create the engine configs.
         vllm_config = engine_args.create_engine_config(usage_context)
-        executor_class = cls._get_executor_cls(vllm_config)
+        executor_class = Executor.get_class(vllm_config)
 
         if VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.debug("Enabling multiprocessing for LLMEngine.")
@@ -103,24 +103,6 @@
                    stat_loggers=stat_loggers,
                    multiprocess_mode=enable_multiprocessing)
 
-    @classmethod
-    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
-        executor_class: Type[Executor]
-        distributed_executor_backend = (
-            vllm_config.parallel_config.distributed_executor_backend)
-        if distributed_executor_backend == "ray":
-            from vllm.v1.executor.ray_executor import RayExecutor
-            executor_class = RayExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
-            executor_class = MultiprocExecutor
-        else:
-            assert (distributed_executor_backend is None)
-            from vllm.v1.executor.uniproc_executor import UniprocExecutor
-            executor_class = UniprocExecutor
-
-        return executor_class
-
     def get_num_unfinished_requests(self) -> int:
         return self.detokenizer.get_num_unfinished_requests()
 
vllm/v1/executor/abstract.py (19 changes: 18 additions & 1 deletion)
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Tuple
+from typing import Tuple, Type
 
 from vllm.config import VllmConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -8,6 +8,23 @@
 class Executor(ABC):
     """Abstract class for executors."""
 
+    @staticmethod
+    def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
+        executor_class: Type[Executor]
+        distributed_executor_backend = (
+            vllm_config.parallel_config.distributed_executor_backend)
+        if distributed_executor_backend == "ray":
+            from vllm.v1.executor.ray_executor import RayExecutor
+            executor_class = RayExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+            executor_class = MultiprocExecutor
+        else:
+            assert (distributed_executor_backend is None)
+            from vllm.v1.executor.uniproc_executor import UniprocExecutor
+            executor_class = UniprocExecutor
+        return executor_class
+
     @abstractmethod
     def __init__(self, vllm_config: VllmConfig) -> None:
         raise NotImplementedError
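
For reference, a minimal sketch of exercising the new resolver, assuming a default single-process configuration (distributed_executor_backend unset, i.e. None, which falls through to the UniprocExecutor branch above; the model name is a placeholder):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.v1.executor.abstract import Executor

    # Build a config the same way the updated tests do.
    engine_args = EngineArgs(model="facebook/opt-125m")  # placeholder model
    vllm_config = engine_args.create_engine_config()

    # With no distributed backend configured, the else branch returns
    # UniprocExecutor; "ray" and "mp" select RayExecutor and
    # MultiprocExecutor respectively.
    executor_class = Executor.get_class(vllm_config)
    print(executor_class.__name__)  # expected: UniprocExecutor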
