Commit
[platform] Allow platform specify attention backend (vllm-project#11609)
Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
2 people authored and frreiss committed Jan 10, 2025
1 parent a7e258c commit 38a0be3
Showing 10 changed files with 164 additions and 175 deletions.
74 changes: 42 additions & 32 deletions tests/kernels/test_attention_selector.py
@@ -1,17 +1,24 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
 from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):

     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "TORCH_SDPA"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "ROCM_FLASH"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "OPENVINO"
+                   OpenVinoPlatform()), patch.dict('sys.modules',
+                                                   {'openvino': Mock()}):
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+                assert backend.get_name() == name


 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend

     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL


 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
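
Note: these tests drive the selector through override_backend_env_variable from tests/kernels/utils.py, which is outside this diff. A minimal sketch of such a helper, assuming STR_BACKEND_ENV_VAR names the VLLM_ATTENTION_BACKEND environment variable referenced elsewhere in this change:

    # Sketch only; the real helper lives in tests/kernels/utils.py and may differ.
    from vllm.utils import STR_BACKEND_ENV_VAR


    def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
        """Force the attention backend for the duration of a single test."""
        monkeypatch.setenv(STR_BACKEND_ENV_VAR, backend_name)

Because get_attn_backend is now served from the _cached_get_attn_backend LRU cache, the autouse clear_cache fixture added above clears that cache so each parametrized case re-runs the selection logic.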
139 changes: 12 additions & 127 deletions vllm/attention/selector.py
@@ -9,7 +9,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

 logger = init_logger(__name__)

@@ -114,83 +114,19 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
-    """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
@@ -201,64 +137,13 @@ def which_attn_to_use(head_size: int,
         if backend_by_env_var is not None:
             selected_backend = backend_name_to_enum(backend_by_env_var)

-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)


 @contextmanager
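Note: resolve_obj_by_qualname is imported from vllm.utils, but its implementation is outside this diff. A plausible sketch, assuming it accepts a dotted path such as the string returned by CpuPlatform.get_attn_backend_cls below:

    # Sketch only; the actual vllm.utils helper may differ in details.
    import importlib
    from typing import Any


    def resolve_obj_by_qualname(qualname: str) -> Any:
        """Resolve an object from its fully qualified dotted name."""
        module_name, obj_name = qualname.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, obj_name)

With this indirection, _cached_get_attn_backend no longer hard-codes one import branch per backend: each platform hands back a qualified class name and the selector imports it lazily.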
7 changes: 5 additions & 2 deletions vllm/platforms/cpu.py
@@ -28,10 +28,13 @@ def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
logger.info("Using Torch SDPA backend.")
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
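Note: the CPU change above illustrates the new contract: a platform returns the fully qualified name of its attention backend class rather than a _Backend enum value. As a purely hypothetical illustration (the plugin, module, and class names are invented, and the Platform base class is assumed to live in vllm.platforms.interface), an out-of-tree platform could route the selector to its own backend like this:

    # Hypothetical sketch; the names below are placeholders, not part of this commit.
    from typing import Optional

    import torch

    from vllm.platforms import _Backend
    from vllm.platforms.interface import Platform


    class MyAcceleratorPlatform(Platform):

        @classmethod
        def get_attn_backend_cls(cls, selected_backend: _Backend,
                                 head_size: int, dtype: torch.dtype,
                                 kv_cache_dtype: Optional[str],
                                 block_size: int, use_v1: bool) -> str:
            # Return the dotted path of the backend class; the selector
            # resolves it with resolve_obj_by_qualname() and imports it lazily.
            return "my_plugin.attention.backends.MyAttentionBackend"

The selector raises ValueError when this hook returns an empty string, so a platform can signal an unsupported configuration by returning "".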
(The remaining 7 changed files are not shown here.)
