From e930130647bda561fcf3c99c3c14166431f9176e Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Mon, 30 Dec 2024 11:09:15 +0800 Subject: [PATCH] load punica wrapper obj dynamically Signed-off-by: Shanshan Shen <467638484@qq.com> --- vllm/lora/punica_wrapper/punica_selector.py | 8 +++++++- vllm/platforms/cuda.py | 8 ++------ vllm/platforms/hpu.py | 8 ++------ vllm/platforms/interface.py | 3 +-- vllm/platforms/rocm.py | 8 ++------ 5 files changed, 14 insertions(+), 21 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index e9ee1234d0585..83a2dc5c8693f 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,7 +1,13 @@ from vllm.platforms import current_platform +from vllm.utils import print_info_once, resolve_obj_by_qualname from .punica_base import PunicaWrapperBase def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: - return current_platform.get_punica_wrapper(*args, **kwargs) + punica_wrapper_qualname = current_platform.get_punica_wrapper() + punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) + punica_wrapper = punica_wrapper_cls(*args, **kwargs) + assert punica_wrapper is not None + print_info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] + ".") + return punica_wrapper diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 03f28c9d767b2..258da1f83744a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -15,9 +15,6 @@ import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase -from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU -from vllm.utils import print_info_once from .interface import DeviceCapability, Platform, PlatformEnum @@ -145,9 +142,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size = 16 @classmethod - def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase: - print_info_once("Using PunicaWrapperGPU.") - return PunicaWrapperGPU(*args, **kwargs) + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" # NVML utils diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index f095c92924758..689e380f773ea 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -3,9 +3,6 @@ import torch from vllm.logger import init_logger -from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase -from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU -from vllm.utils import print_info_once from .interface import Platform, PlatformEnum, _Backend @@ -63,6 +60,5 @@ def is_pin_memory_available(cls): return False @classmethod - def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase: - print_info_once("Using PunicaWrapperHPU.") - return PunicaWrapperHPU(*args, **kwargs) + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 29f756f784d49..c8e05d51b07f6 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -8,7 +8,6 @@ import torch from vllm.logger import init_logger -from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase if TYPE_CHECKING: from vllm.config import VllmConfig @@ -240,7 +239,7 @@ def is_pin_memory_available(cls) -> bool: return True @classmethod - def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase: + def get_punica_wrapper(cls) -> str: """ Return the punica wrapper for current platform. """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b455a428e3a66..60201e4bd0fd1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -6,9 +6,6 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase -from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU -from vllm.utils import print_info_once from .interface import DeviceCapability, Platform, PlatformEnum, _Backend @@ -115,6 +112,5 @@ def verify_quantization(cls, quant: str) -> None: envs.VLLM_USE_TRITON_AWQ = True @classmethod - def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase: - print_info_once("Using PunicaWrapperGPU.") - return PunicaWrapperGPU(*args, **kwargs) + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"