[Platform] Move get_punica_wrapper() function to Platform (#11516)
Signed-off-by: Shanshan Shen <[email protected]>
shen-shanshan authored Jan 13, 2025
1 parent 458e63a commit a7d5968
Showing 6 changed files with 32 additions and 17 deletions.
26 changes: 9 additions & 17 deletions vllm/lora/punica_wrapper/punica_selector.py
@@ -1,26 +1,18 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname
 
 from .punica_base import PunicaWrapperBase
 
 logger = init_logger(__name__)
 
 
 def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
-    if current_platform.is_cuda_alike():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
-        logger.info_once("Using PunicaWrapperGPU.")
-        return PunicaWrapperGPU(*args, **kwargs)
-    elif current_platform.is_cpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
-        logger.info_once("Using PunicaWrapperCPU.")
-        return PunicaWrapperCPU(*args, **kwargs)
-    elif current_platform.is_hpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
-        logger.info_once("Using PunicaWrapperHPU.")
-        return PunicaWrapperHPU(*args, **kwargs)
-    else:
-        raise NotImplementedError
+    punica_wrapper_qualname = current_platform.get_punica_wrapper()
+    punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
+    punica_wrapper = punica_wrapper_cls(*args, **kwargs)
+    assert punica_wrapper is not None, \
+        "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
+    logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
+                     ".")
+    return punica_wrapper
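
For reference, resolve_obj_by_qualname (imported above from vllm.utils) is the piece that turns each platform's qualname string into a class. Below is a minimal sketch of what such a helper typically does, assuming it splits the dotted path and imports the module lazily; the name resolve_obj_by_qualname_sketch is illustrative and the real vLLM helper may differ in details:

    import importlib
    from typing import Any

    def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
        # Split "pkg.module.ClassName" into a module path and an attribute name.
        module_name, obj_name = qualname.rsplit(".", 1)
        # Importing here keeps platform-specific modules out of the import graph
        # until a wrapper is actually requested.
        module = importlib.import_module(module_name)
        return getattr(module, obj_name)

This preserves the lazy-import behavior of the old if/elif chain: a platform's punica module is only imported when its qualname is resolved.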
4 changes: 4 additions & 0 deletions vllm/platforms/cpu.py
@@ -109,3 +109,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
@@ -218,6 +218,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         logger.info("Using Flash Attention backend.")
         return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
4 changes: 4 additions & 0 deletions vllm/platforms/hpu.py
@@ -63,3 +63,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
7 changes: 7 additions & 0 deletions vllm/platforms/interface.py
@@ -276,6 +276,13 @@ def is_pin_memory_available(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        """
+        Return the punica wrapper for current platform.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
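
With the hook declared on the base Platform class, adding LoRA support for another backend only requires returning a fully qualified class name; punica_selector.py no longer needs editing. A hypothetical sketch follows — MyPlatform, the my_plugin.lora.punica_mydevice.PunicaWrapperMyDevice qualname, and the enum choice are illustrative, not part of this commit:

    from vllm.platforms.interface import Platform, PlatformEnum

    class MyPlatform(Platform):
        # Placeholder enum member for the sketch; a real backend would register
        # its own platform enum value.
        _enum = PlatformEnum.UNSPECIFIED

        @classmethod
        def get_punica_wrapper(cls) -> str:
            # The selector resolves this qualname lazily, so the wrapper module
            # is only imported when LoRA is actually used on this platform.
            return "my_plugin.lora.punica_mydevice.PunicaWrapperMyDevice"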
4 changes: 4 additions & 0 deletions vllm/platforms/rocm.py
@@ -153,3 +153,7 @@ def verify_quantization(cls, quant: str) -> None:
                 "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                 " is not set, enabling VLLM_USE_TRITON_AWQ.")
             envs.VLLM_USE_TRITON_AWQ = True
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
