Update neuron interface
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan committed Jan 7, 2025
1 parent 12f43b4 commit 7e83803
Showing 2 changed files with 41 additions and 54 deletions.
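The heart of the change: `which_attn_to_use`, which returned a fully-qualified class-name string, is folded into the selector, and callers now receive the backend class itself from `get_attn_backend` and inspect it through `get_name()`. A minimal sketch of the new call pattern, assuming vLLM and at least one supported backend are installed (the positional argument order follows the tests below: head_size, dtype, kv_cache_dtype, block_size, is_attention_free):

import torch

from vllm.attention.selector import get_attn_backend

# get_attn_backend now returns the backend class rather than a qualname
# string; the reported name depends on the platform and installed kernels,
# e.g. "FLASH_ATTN" on CUDA or "TORCH_SDPA" on CPU.
backend = get_attn_backend(16, torch.float16, None, 16, False)
print(backend.get_name())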
67 changes: 32 additions & 35 deletions tests/kernels/test_attention_selector.py
@@ -4,7 +4,7 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -16,78 +16,75 @@
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
-    """Test that the attention selector can be set via environment variable."""
+    """Test that the attention selector can be set via environment variable.
+    Note that we do not test FlashAttn because it is the default backend.
+    """
 
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "ROCM_FLASH"
    elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
                    OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "OPENVINO"
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        if name == "FLASHINFER":
-            assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-        if name == "XFORMERS":
-            assert backend == "vllm.attention.backends.xformers.XFormersBackend"
-        else:
-            assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
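For reference, these tests drive the selector through a backend-override environment variable (via `override_backend_env_variable`) and a patched `vllm.attention.selector.current_platform`. A standalone sketch of the same pattern follows; the variable name `VLLM_ATTENTION_BACKEND` is assumed here rather than shown in this diff:

import os
from unittest.mock import patch

import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.cpu import CpuPlatform

# Force the selector's choice (assumed variable name), then patch the
# platform object the selector consults, mirroring test_env above.
os.environ["VLLM_ATTENTION_BACKEND"] = "TORCH_SDPA"
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert backend.get_name() == "TORCH_SDPA"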
28 changes: 9 additions & 19 deletions vllm/attention/selector.py
@@ -114,25 +114,12 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
-                                      block_size, is_attention_free, use_v1)
-    assert attention_cls != "", (
-        f"Invalid attention backend for {current_platform.device_name}")
-
-    return resolve_obj_by_qualname(attention_cls)
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> str:
-    """Returns which flash attention backend to use."""
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"  # noqa: E501
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
@@ -151,9 +138,12 @@ def which_attn_to_use(head_size: int,
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
     # get device-specific attn_backend
-    return current_platform.get_attn_backend_cls(selected_backend, head_size,
-                                                 dtype, kv_cache_dtype,
-                                                 block_size, use_v1)
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
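With this refactor the selector delegates the actual decision to the active platform: `get_attn_backend_cls` returns a fully-qualified class name, which the selector resolves with `resolve_obj_by_qualname`, raising `ValueError` when the platform returns an empty string (this replaces the old assert). A hedged sketch of a platform-side implementation of this hook; the class and module names below are illustrative assumptions, not code from this commit:

from typing import Optional

import torch


class MyPlatform:
    """Stand-in for a vllm.platforms.Platform subclass (hypothetical)."""

    device_name = "my_device"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size: int,
                             dtype: torch.dtype,
                             kv_cache_dtype: Optional[str], block_size: int,
                             use_v1: bool) -> str:
        # The selector passes these arguments (see the hunk above) and expects
        # a fully-qualified class name back; an empty string now raises
        # ValueError instead of tripping an assert.
        if use_v1:
            return ""  # no v1 backend for this hypothetical device
        return "my_plugin.attention.backends.MyAttentionBackend"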
