[platform] Allow platform to specify attention backend
wangxiyuan committed Dec 30, 2024
1 parent 0aa38d1 commit e06ad57
Showing 2 changed files with 15 additions and 1 deletion.
9 changes: 9 additions & 0 deletions vllm/attention/selector.py
@@ -168,6 +168,15 @@ def _cached_get_attn_backend(
             PlaceholderAttentionBackend)
         return PlaceholderAttentionBackend
     else:
+        # If the backend is not specified, it may be a plugin platform. Use the
+        # default backend impl from it instead.
+        impl = current_platform.get_default_attn_backend_impl()
+        if impl:
+            assert callable(impl), (
+                "The default attention backend implementation is not callable, "
+                f"platform: {current_platform.device_name}")
+            return impl
+
         raise ValueError("Invalid attention backend.")
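
The hunk above adds a fallback: when none of the in-tree backends match, the selector asks the current platform for a default backend implementation before giving up. Below is a minimal, self-contained sketch of that control flow; the DummyPlatform and DummyAttentionBackend names and the standalone select_backend helper are illustrative stand-ins, not part of vLLM.

from typing import Callable, Optional


class DummyAttentionBackend:
    """Stand-in for a plugin's attention backend class."""


class DummyPlatform:
    """Stand-in for a Platform subclass provided by an out-of-tree plugin."""
    device_name = "dummy"

    @classmethod
    def get_default_attn_backend_impl(cls) -> Optional[Callable]:
        # A class is callable, so returning the backend class satisfies the check.
        return DummyAttentionBackend


def select_backend(current_platform) -> Callable:
    # Mirrors the new fallback branch in _cached_get_attn_backend: if no
    # in-tree backend matched, ask the platform for its default implementation.
    impl = current_platform.get_default_attn_backend_impl()
    if impl:
        assert callable(impl), (
            "The default attention backend implementation is not callable, "
            f"platform: {current_platform.device_name}")
        return impl
    raise ValueError("Invalid attention backend.")


print(select_backend(DummyPlatform()))  # <class '__main__.DummyAttentionBackend'>

If the platform returns None (the base-class default, see the interface change below), the selector still raises ValueError as before.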


7 changes: 6 additions & 1 deletion vllm/platforms/interface.py
@@ -2,7 +2,7 @@
 import platform
 import random
 from platform import uname
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -116,6 +116,11 @@ def get_default_attn_backend(cls, selected_backend: _Backend):
         """Get the default attention backend of a device."""
         return None

+    @classmethod
+    def get_default_attn_backend_impl(cls) -> Optional[Callable]:
+        """Get the default attention backend implementation of a device."""
+        return None
+
     @classmethod
     def get_device_capability(
         cls,
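
With the new classmethod on the platform interface, an out-of-tree platform can hand the selector its own backend by overriding the hook. A sketch under stated assumptions: MyPluginPlatform and MyAttentionBackend are hypothetical names, the base class is assumed to be importable as vllm.platforms.interface.Platform, and how the plugin registers itself as the current platform is outside the scope of this commit.

from typing import Callable, Optional

from vllm.platforms.interface import Platform


class MyAttentionBackend:
    """Hypothetical backend; a real plugin would provide a full attention backend class."""


class MyPluginPlatform(Platform):
    """Hypothetical out-of-tree platform shipped by a plugin package."""
    device_name = "my_device"

    @classmethod
    def get_default_attn_backend_impl(cls) -> Optional[Callable]:
        # Return a callable (here, the backend class) for the selector to use.
        return MyAttentionBackend

Returning the class itself is enough to pass the selector's callable() assertion; platforms that do not override the hook keep the default None and the previous behavior.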
