diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 04c56a1e6a3bd..40516320b0a46 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -4,7 +4,7 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -16,78 +16,75 @@
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
-    """Test that the attention selector can be set via environment variable."""
+    """Test that the attention selector can be set via environment variable.
+    Note that we do not test FlashAttn because it is the default backend.
+    """
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
                    OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "OPENVINO"
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        if name == "FLASHINFER":
-            assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-        if name == "XFORMERS":
-            assert backend == "vllm.attention.backends.xformers.XFormersBackend"
-        else:
-            assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index ca6e3f6b63253..0ff007c87b1c9 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -114,25 +114,12 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
-                                      block_size, is_attention_free, use_v1)
-    assert attention_cls != "", (
-        f"Invalid attention backend for {current_platform.device_name}")
-
-    return resolve_obj_by_qualname(attention_cls)
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> str:
-    """Returns which flash attention backend to use."""
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"  # noqa: E501
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
@@ -151,9 +138,12 @@ def which_attn_to_use(head_size: int,
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
     # get device-specific attn_backend
-    return current_platform.get_attn_backend_cls(selected_backend, head_size,
-                                                 dtype, kv_cache_dtype,
-                                                 block_size, use_v1)
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index a4bbbd27c8a89..1da2d0f6c5cce 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -1,8 +1,10 @@
 from typing import TYPE_CHECKING, Optional
 
+import torch
+
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
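
Usage note (not part of the diff): a minimal sketch of how the reworked selector is called, mirroring the updated tests above. The argument order (head_size, dtype, kv_cache_dtype, block_size, is_attention_free) and the get_name() call are taken from this change; the device and the printed name depend on the local environment and are only illustrative.

    import torch

    from vllm.attention.selector import get_attn_backend

    # Resolve the attention backend class for the current platform. With the
    # reworked selector this returns the backend class directly (instead of a
    # qualname string) and raises ValueError if the platform provides none.
    backend_cls = get_attn_backend(16, torch.float16, None, 16, False)
    print(backend_cls.get_name())  # e.g. "TORCH_SDPA" on a CPU-only install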