Update neuron interface
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan committed Jan 6, 2025
1 parent bb83429 commit 4195f8e
Showing 3 changed files with 119 additions and 55 deletions.
67 changes: 32 additions & 35 deletions tests/kernels/test_attention_selector.py
@@ -4,7 +4,7 @@
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.attention.selector import get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
@@ -16,78 +16,75 @@
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable."""
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""

override_backend_env_variable(monkeypatch, name)

if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend" # noqa: E501
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
if name == "FLASHINFER":
assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
if name == "XFORMERS":
assert backend == "vllm.attention.backends.xformers.XFormersBackend"
else:
assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name


def test_flash_attn(monkeypatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use
# get_attn_backend

override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL

# Unsupported data type
backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL

# Unsupported kv cache data type
backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL

# Unsupported block size
backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
backend = get_attn_backend(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL

# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL

# Unsupported head size
backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL

# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
assert backend.name != STR_FLASH_ATTN_VAL


def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(16, torch.float16, None, 16, False)
get_attn_backend(16, torch.float16, None, 16, False)
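
The substantive change in this test file: get_attn_backend returns a backend class whose short name is compared, instead of the qualified-name string that which_attn_to_use used to return. A minimal, self-contained sketch of the new pattern, reusing the call signature and CPU-platform patch shown in the diff (the assertion mirrors the updated test rather than an independent claim about vLLM internals):

# Sketch of the selection pattern exercised by the updated tests.
# Signature assumed from the diff: get_attn_backend(head_size, dtype,
#     kv_cache_dtype, block_size, is_attention_free).
from unittest.mock import patch

import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.cpu import CpuPlatform

# Force the CPU platform, as the test does, then ask for a backend.
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)

# The selector now hands back a backend class; the tests compare its short
# name rather than its import path.
assert backend.name == "TORCH_SDPA"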
27 changes: 8 additions & 19 deletions vllm/attention/selector.py
@@ -114,25 +114,12 @@ def _cached_get_attn_backend(
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend

attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
block_size, is_attention_free, use_v1)
assert attention_cls != "", (
f"Invalid attention backend for {current_platform.device_name}")

return resolve_obj_by_qualname(attention_cls)


def which_attn_to_use(head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
use_v1: bool = False) -> str:
"""Returns which flash attention backend to use."""
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend" # noqa: E501
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend

# Check whether a particular choice of backend was
# previously forced.
@@ -151,9 +138,11 @@ def which_attn_to_use(head_size: int,
selected_backend = backend_name_to_enum(backend_by_env_var)

# get device-specific attn_backend
return current_platform.get_attn_backend_cls(selected_backend, head_size,
dtype, kv_cache_dtype,
block_size, use_v1)
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
assert attention_cls != "", (
f"Invalid attention backend for {current_platform.device_name}")
return resolve_obj_by_qualname(attention_cls)


@contextmanager
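
_cached_get_attn_backend now asks the active platform for a fully qualified backend class name and only then turns that string into a class with resolve_obj_by_qualname. A minimal sketch of what such a resolver typically does, assuming a plain importlib lookup (vLLM's actual helper in vllm.utils may differ in detail):

# Hedged approximation of a qualname resolver like resolve_obj_by_qualname.
import importlib
from typing import Any


def resolve_qualname(qualname: str) -> Any:
    """Turn 'pkg.module.ClassName' into the ClassName object."""
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Example: resolve the string a platform hook returns into a backend class.
# backend_cls = resolve_qualname(
#     "vllm.attention.backends.flash_attn.FlashAttentionBackend")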
80 changes: 79 additions & 1 deletion vllm/platforms/neuron.py
@@ -1,8 +1,10 @@
from typing import TYPE_CHECKING, Optional

import torch

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -43,3 +45,79 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")
return False

@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str:
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.")
return "vllm.attention.backends.flashinfer.FlashInferBackend"
elif selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.FLASH_ATTN:
logger.info("Using FlashAttention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}")

target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
target_backend = _Backend.XFORMERS
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif kv_cache_dtype is not None and \
kv_cache_dtype.startswith("fp8"):
logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
target_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
target_backend = _Backend.XFORMERS

# FlashAttn is valid for the model, checking if the package is
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)

supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
target_backend = _Backend.XFORMERS
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
target_backend = _Backend.XFORMERS

if target_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"

logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
