From 64bfd4551e1a05d5aa82441c3d68ef9920918422 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Tue, 24 Dec 2024 19:49:46 -0800 Subject: [PATCH] [V1] Unify VLLM_ENABLE_V1_MULTIPROCESSING handling in RayExecutor (#11472) --- tests/basic_correctness/test_basic_correctness.py | 5 ----- vllm/v1/engine/llm_engine.py | 2 -- vllm/v1/executor/ray_executor.py | 5 ++++- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 9e4eb16fc6cc5..1c2193bb17a55 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -127,11 +127,6 @@ def test_models_distributed( if attention_backend: os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend - # Import VLLM_USE_V1 dynamically to handle patching - from vllm.envs import VLLM_USE_V1 - if VLLM_USE_V1 and distributed_executor_backend != "mp": - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - dtype = "half" max_tokens = 5 diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 9ad51575b3cc3..b58f62778ffe9 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,7 +21,6 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.executor.ray_utils import initialize_ray_cluster logger = init_logger(__name__) @@ -112,7 +111,6 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: distributed_executor_backend = ( vllm_config.parallel_config.distributed_executor_backend) if distributed_executor_backend == "ray": - initialize_ray_cluster(vllm_config.parallel_config) from vllm.v1.executor.ray_executor import RayExecutor executor_class = RayExecutor elif distributed_executor_backend == "mp": diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index dfeb69fa701a3..79acc60001c99 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -8,7 +8,8 @@ from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import Executor -from vllm.v1.executor.ray_utils import RayWorkerWrapper, ray +from vllm.v1.executor.ray_utils import (RayWorkerWrapper, + initialize_ray_cluster, ray) from vllm.v1.outputs import ModelRunnerOutput if ray is not None: @@ -33,7 +34,9 @@ def __init__(self, vllm_config: VllmConfig) -> None: if ray_usage != "1": os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + initialize_ray_cluster(self.parallel_config) placement_group = self.parallel_config.placement_group + # Create the parallel GPU workers. self._init_workers_ray(placement_group)