Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Take user preference in attention selector #4960

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/kernels/test_attention_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
from unittest.mock import patch

import pytest
import torch

from vllm.attention.selector import which_attn_to_use


@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(name: str, device: str):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = name

if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "ROCM_FLASH"
else:
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == name

if name_backup is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_flash_attn():
"""Test FlashAttn validation."""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"

# Unsupported data type
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
assert backend.name != "FLASH_ATTN"

# Unsupported kv cache data type
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
assert backend.name != "FLASH_ATTN"

# Unsupported block size
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
assert backend.name != "FLASH_ATTN"

# Unsupported sliding window
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"

# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"

# Unsupported head size
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN"

if name_backup is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_invalid_env():
"""Throw an exception if the backend name is invalid."""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
with pytest.raises(ValueError):
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
1 change: 1 addition & 0 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def forward(
)

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
output = flash_attn_varlen_func(
Expand Down
145 changes: 84 additions & 61 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,16 @@ def get_attn_backend(
kv_cache_dtype: Optional[str],
block_size: int,
) -> Type[AttentionBackend]:
backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
"""Determine which attention backend to use and only import
the selected backend module.
"""
backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)

# We check it here not in _which_attn_to_use because we cannot know
# the head size until we import FlashAttentionBackend.
supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size in supported_head_sizes:
logger.info("Using FlashAttention-2 backend.")
return FlashAttentionBackend
logger.info(
"Cannot use FlashAttention-2 backend for head size %d. "
"Using XFormers backend instead.", head_size)
backend = _Backend.XFORMERS

return FlashAttentionBackend
if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
Expand All @@ -64,14 +56,15 @@ def get_attn_backend(
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is enforced for the Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
else:
raise ValueError("Invalid attention backend.")


def _which_attn_to_use(
def which_attn_to_use(
num_heads: int,
head_size: int,
num_kv_heads: int,
Expand All @@ -81,54 +74,84 @@ def _which_attn_to_use(
block_size: int,
) -> _Backend:
"""Returns which flash attention backend to use."""

# Default case.
selected_backend = _Backend.FLASH_ATTN

# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
backend_members = _Backend.__members__
if backend_by_env_var not in backend_members:
raise ValueError(
f"Invalid attention backend '{backend_by_env_var}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
selected_backend = _Backend[backend_by_env_var]

if is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

if is_hip():
# AMD GPUs.
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_atten is not supported on NAVI GPUs.")
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

# NVIDIA GPUs.
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
return _Backend.XFORMERS

if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS

if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
return _Backend.XFORMERS

if block_size % 16 != 0:
logger.info("Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
return _Backend.XFORMERS

if sliding_window is not None:
logger.info(
"Cannot use FlashAttention-2 backend due to sliding window.")
return _Backend.XFORMERS

try:
import vllm_flash_attn # noqa: F401
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
"package is not found. `pip install vllm-flash-attn` for better "
"performance.")
return _Backend.XFORMERS

backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]

# Default case.
return _Backend.FLASH_ATTN
# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
selected_backend = _Backend.XFORMERS
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
selected_backend = _Backend.XFORMERS
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
selected_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
selected_backend = _Backend.XFORMERS
elif sliding_window is not None:
logger.info(
"Cannot use FlashAttention-2 backend due to sliding window.")
selected_backend = _Backend.XFORMERS

# FlashAttn is valid for the model, checking if the package is installed.
if selected_backend == _Backend.FLASH_ATTN:
try:
import vllm_flash_attn # noqa: F401

from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)

supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
selected_backend = _Backend.XFORMERS
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm_flash_attn package is not found. "
"`pip install vllm-flash-attn` for better performance.")
selected_backend = _Backend.XFORMERS

return selected_backend
Loading