refactor get_punica_wrapper() into platform
Signed-off-by: Shanshan Shen <[email protected]>
shen-shanshan committed Dec 26, 2024
1 parent 12b41b1 commit 69797a8
Showing 8 changed files with 43 additions and 38 deletions.
14 changes: 1 addition & 13 deletions vllm/lora/punica_wrapper/punica_selector.py
@@ -1,19 +1,7 @@
 from vllm.platforms import current_platform
-from vllm.utils import print_info_once

 from .punica_base import PunicaWrapperBase


 def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
-    if current_platform.is_cuda_alike():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
-        print_info_once("Using PunicaWrapperGPU.")
-        return PunicaWrapperGPU(*args, **kwargs)
-    elif current_platform.is_hpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
-        print_info_once("Using PunicaWrapperHPU.")
-        return PunicaWrapperHPU(*args, **kwargs)
-    else:
-        raise NotImplementedError
+    return current_platform.get_punica_wrapper(*args, **kwargs)
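
For context, the shape of the refactor in one place: the selector stops branching on is_cuda_alike()/is_hpu() and delegates to a get_punica_wrapper() classmethod on whichever platform class is active. Below is a minimal, standalone sketch of that dispatch pattern; every class in it is an illustrative stand-in, not vLLM's real Platform or Punica wrapper implementation, and constructor arguments are elided because they do not appear in this diff.

# Minimal, standalone sketch of the dispatch pattern adopted by this commit.
# All classes below are illustrative stand-ins, not vLLM's real ones.

class PunicaWrapperBase:
    """Stand-in for vllm.lora.punica_wrapper.punica_base.PunicaWrapperBase."""


class PunicaWrapperGPU(PunicaWrapperBase):
    """Stand-in for the CUDA/ROCm wrapper."""


class Platform:
    """Stand-in for vllm.platforms.interface.Platform."""

    @classmethod
    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
        # The base interface leaves the hook unimplemented (see interface.py).
        raise NotImplementedError


class CudaPlatform(Platform):
    @classmethod
    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
        # Each platform returns the wrapper it supports (see cuda.py / rocm.py).
        return PunicaWrapperGPU(*args, **kwargs)


# Stand-in for vllm.platforms.current_platform, which vLLM resolves at import time.
current_platform = CudaPlatform


def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
    # The selector no longer needs per-device branches; it just delegates.
    return current_platform.get_punica_wrapper(*args, **kwargs)


print(type(get_punica_wrapper()).__name__)  # prints: PunicaWrapperGPU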
7 changes: 7 additions & 0 deletions vllm/platforms/cpu.py
@@ -4,6 +4,8 @@
 import torch

 from vllm.logger import init_logger
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+from vllm.utils import print_info_once

 from .interface import Platform, PlatformEnum, _Backend

@@ -106,3 +108,8 @@ def is_pin_memory_available(cls) -> bool:
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
+        print_info_once("PunicaWrapperCPU is not implemented yet.")
+        raise NotImplementedError
11 changes: 6 additions & 5 deletions vllm/platforms/cuda.py
@@ -15,6 +15,9 @@
 import vllm._C  # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
+from vllm.utils import print_info_once

 from .interface import DeviceCapability, Platform, PlatformEnum

@@ -142,11 +145,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 16

     @classmethod
-    def get_current_memory_usage(cls,
-                                 device: Optional[torch.types.Device] = None
-                                 ) -> float:
-        torch.cuda.reset_peak_memory_stats(device)
-        return torch.cuda.max_memory_allocated(device)
+    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
+        print_info_once("Using PunicaWrapperGPU.")
+        return PunicaWrapperGPU(*args, **kwargs)


 # NVML utils
8 changes: 8 additions & 0 deletions vllm/platforms/hpu.py
@@ -3,6 +3,9 @@
 import torch

 from vllm.logger import init_logger
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
+from vllm.utils import print_info_once

 from .interface import Platform, PlatformEnum, _Backend

@@ -58,3 +61,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
+        print_info_once("Using PunicaWrapperHPU.")
+        return PunicaWrapperHPU(*args, **kwargs)
7 changes: 3 additions & 4 deletions vllm/platforms/interface.py
@@ -8,6 +8,7 @@
 import torch

 from vllm.logger import init_logger
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -239,11 +240,9 @@ def is_pin_memory_available(cls) -> bool:
         return True

     @classmethod
-    def get_current_memory_usage(cls,
-                                 device: Optional[torch.types.Device] = None
-                                 ) -> float:
+    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
         """
-        Return the memory usage in bytes.
+        Return the punica wrapper for the current platform.
         """
         raise NotImplementedError

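
One consequence of placing the hook on the Platform interface is that adding LoRA dispatch for another backend only requires overriding a single classmethod instead of editing punica_selector.py. A hedged sketch of that extension point follows, again with stand-in classes; HypotheticalPlatform and PunicaWrapperHypothetical are invented names and not part of vLLM.

# Hedged sketch of the extension point added to interface.py; the names
# below are invented and the base classes are stand-ins, not vLLM's.

class PunicaWrapperBase:
    """Stand-in wrapper base class."""


class PunicaWrapperHypothetical(PunicaWrapperBase):
    """Invented wrapper for an imaginary accelerator."""


class Platform:
    @classmethod
    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
        raise NotImplementedError


class HypotheticalPlatform(Platform):
    # The only LoRA-dispatch code a new platform needs to provide.
    @classmethod
    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
        return PunicaWrapperHypothetical(*args, **kwargs)


assert isinstance(HypotheticalPlatform.get_punica_wrapper(), PunicaWrapperBase)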
11 changes: 6 additions & 5 deletions vllm/platforms/rocm.py
@@ -6,6 +6,9 @@

 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
+from vllm.utils import print_info_once

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -112,8 +115,6 @@ def verify_quantization(cls, quant: str) -> None:
             envs.VLLM_USE_TRITON_AWQ = True

     @classmethod
-    def get_current_memory_usage(cls,
-                                 device: Optional[torch.types.Device] = None
-                                 ) -> float:
-        torch.cuda.reset_peak_memory_stats(device)
-        return torch.cuda.max_memory_allocated(device)
+    def get_punica_wrapper(cls, *args, **kwargs) -> PunicaWrapperBase:
+        print_info_once("Using PunicaWrapperGPU.")
+        return PunicaWrapperGPU(*args, **kwargs)
7 changes: 0 additions & 7 deletions vllm/platforms/xpu.py
@@ -87,10 +87,3 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
-
-    @classmethod
-    def get_current_memory_usage(cls,
-                                 device: Optional[torch.types.Device] = None
-                                 ) -> float:
-        torch.xpu.reset_peak_memory_stats(device)
-        return torch.xpu.max_memory_allocated(device)
16 changes: 12 additions & 4 deletions vllm/utils.py
@@ -680,15 +680,23 @@ class DeviceMemoryProfiler:
     def __init__(self, device: Optional[torch.types.Device] = None):
         self.device = device

+    def current_memory_usage(self) -> float:
+        # Return the memory usage in bytes.
+        if current_platform.is_cuda_alike():
+            torch.cuda.reset_peak_memory_stats(self.device)
+            mem = torch.cuda.max_memory_allocated(self.device)
+        elif current_platform.is_xpu():
+            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
+            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
+        return mem
+
     def __enter__(self):
-        self.initial_memory = current_platform.get_current_memory_usage(
-            self.device)
+        self.initial_memory = self.current_memory_usage()
         # This allows us to call methods of the context manager if needed
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.final_memory = current_platform.get_current_memory_usage(
-            self.device)
+        self.final_memory = self.current_memory_usage()
         self.consumed_memory = self.final_memory - self.initial_memory

         # Force garbage collection
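
The DeviceMemoryProfiler change above folds the memory probing back into the profiler itself, since the per-platform get_current_memory_usage() hook is removed. A short usage sketch, assuming vLLM is installed and a CUDA device is available; consumed_memory is the attribute set in __exit__ above.

import torch

from vllm.utils import DeviceMemoryProfiler

# Allocations made inside the block are attributed to it; consumed_memory
# holds final minus initial usage in bytes once the block exits.
with DeviceMemoryProfiler(torch.device("cuda:0")) as profiler:
    weights = torch.empty(1024, 1024, device="cuda:0")

print(f"consumed {profiler.consumed_memory / 1024 ** 2:.1f} MiB")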
