[platform] fix attn backend for cuda
Signed-off-by: Mengqing Cao <[email protected]>
MengqingCao committed Jan 9, 2025
1 parent 7e83803 commit afb6c12
Showing 2 changed files with 15 additions and 7 deletions.
19 changes: 14 additions & 5 deletions tests/kernels/test_attention_selector.py
@@ -4,14 +4,21 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
 from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -39,10 +46,12 @@ def test_env(name: str, device: str, monkeypatch):
                                         False)
         assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.get_name() == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                            16, False)
+                assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
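The new autouse fixture is needed because `get_attn_backend` delegates to `_cached_get_attn_backend`, which is memoized (the fixture's own docstring refers to clearing an lru cache), so the backend resolved in one parametrized test case would otherwise be reused by the next one. Below is a minimal sketch of that caching mechanic; `pick_backend` is a hypothetical stand-in, not vLLM's real selector.

```python
# Minimal sketch of the lru_cache behavior the autouse fixture resets.
# `pick_backend` is a hypothetical stand-in for the cached selector function;
# only the cache_clear() mechanics are the point here.
from functools import lru_cache


@lru_cache(maxsize=None)
def pick_backend(name: str) -> str:
    # Pretend this is an expensive, platform-dependent lookup; the result is
    # memoized per argument tuple.
    return f"resolved:{name}"


pick_backend("XFORMERS")    # computed on the first call
pick_backend("XFORMERS")    # served from the cache, even if the env changed
pick_backend.cache_clear()  # what clear_cache() does before each test case
pick_backend("XFORMERS")    # computed fresh again
```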
3 changes: 1 addition & 2 deletions vllm/platforms/cuda.py
@@ -154,8 +154,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             logger.info("Using XFormers backend.")
             return "vllm.attention.backends.xformers.XFormersBackend"
         elif selected_backend == _Backend.FLASH_ATTN:
-            logger.info("Using FlashAttention backend.")
-            return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+            pass
         elif selected_backend:
             raise ValueError(
                 f"Invalid attention backend for {cls.device_name}")
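On the cuda.py side, an explicitly selected FLASH_ATTN no longer returns the backend path immediately; the `pass` lets control fall through to the selection logic after the elif chain, so FlashAttention is presumably only chosen once the default path's checks pass. A generic sketch of that elif-with-pass pattern follows; the names and backend strings (`get_backend_cls`, `flash_attn_is_usable`) are illustrative assumptions, not vLLM's API.

```python
# Hypothetical sketch of the elif-with-pass pattern: an explicitly requested
# FLASH_ATTN no longer short-circuits, so the shared default path below can
# still veto it. All names and strings here are illustrative only.
from typing import Optional


def flash_attn_is_usable() -> bool:
    # Placeholder for real suitability checks (device capability, dtype,
    # head size, package availability, ...).
    return True


def get_backend_cls(selected_backend: Optional[str]) -> str:
    if selected_backend == "XFORMERS":
        return "xformers.XFormersBackend"
    elif selected_backend == "FLASH_ATTN":
        pass  # fall through instead of returning early
    elif selected_backend:
        raise ValueError(f"Invalid attention backend: {selected_backend}")

    # Default path, now also reached when FLASH_ATTN was requested.
    if flash_attn_is_usable():
        return "flash_attn.FlashAttentionBackend"
    return "xformers.XFormersBackend"
```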
