[Kernel] Move attn_type to Attention.__init__() #11690

Merged: 5 commits, Jan 6, 2025
Changes from 3 commits
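
This PR removes the per-call attn_type argument from every attention backend's forward() and makes it an __init__() argument instead. As a rough caller-side illustration (a sketch under assumptions, not code from this diff: the Attention layer's full constructor signature lives in vllm/attention/layer.py, which is not among the files shown here):

from vllm.attention import Attention, AttentionType

# Before this change, the attention type was threaded through every call:
#     out = attn(q, k, v, kv_cache, attn_metadata,
#                attn_type=AttentionType.ENCODER)
#
# After this change, the type is fixed once at construction and forward()
# no longer accepts it (assumes a configured vLLM runtime environment):
attn = Attention(num_heads=8,
                 head_size=64,
                 scale=64**-0.5,
                 attn_type=AttentionType.ENCODER)
# out = attn(q, k, v, kv_cache, attn_metadata)
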
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -233,6 +233,7 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
raise NotImplementedError

@@ -246,7 +247,6 @@ def forward(
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
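
Concretely, a backend written against this updated AttentionImpl interface receives attn_type once at construction and either rejects unsupported types there or stores it for use in forward(). Below is a minimal sketch of a decoder-only implementation; the class name is hypothetical, and the leading constructor parameters are assumed to match the existing interface since only the tail of the signature appears in the hunk above.

from typing import Any, Dict, List, Optional

import torch

from vllm.attention.backends.abstract import AttentionImpl, AttentionType


class MyDecoderOnlyAttentionImpl(AttentionImpl):
    """Hypothetical backend illustrating the new attn_type handling."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        # Unsupported types are now rejected at model build time instead of
        # on the first forward() call.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and encoder/decoder cross-attention "
                "are not implemented in this sketch.")
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # attn_type is no longer a forward() argument; branches that need it
        # read self.attn_type instead.
        raise NotImplementedError("Kernel invocation omitted in this sketch.")
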
14 changes: 7 additions & 7 deletions vllm/attention/backends/blocksparse_attn.py
@@ -300,6 +300,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
@@ -350,6 +351,12 @@ def __init__(
active_head_range=self.blocksparse_params.active_head_range,
)

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")

def forward(
self,
query: torch.Tensor,
@@ -359,7 +366,6 @@ def forward(
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -375,12 +381,6 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
4 changes: 3 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -600,6 +600,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -627,6 +628,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type

def forward(
self,
@@ -637,7 +639,6 @@ def forward(
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -659,6 +660,7 @@ def forward(

assert output is not None, "Output tensor must be provided."

attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
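
Unlike the decoder-only backends, FlashAttention supports encoder self-attention and encoder/decoder cross-attention, so instead of raising in __init__() it stores the value as self.attn_type and reads it back at the start of forward(). Torch SDPA and xFormers below use the same store-and-read pattern, while FlashInfer, HPU, IPEX, Pallas, and ROCm FlashAttention move their decoder-only check into __init__() like BlocksparseFlashAttention above.
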
15 changes: 7 additions & 8 deletions vllm/attention/backends/flashinfer.py
@@ -748,6 +748,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -764,6 +765,12 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")

def forward(
self,
query: torch.Tensor,
@@ -773,18 +780,10 @@ def forward(
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")

num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
13 changes: 7 additions & 6 deletions vllm/attention/backends/hpu_attn.py
@@ -102,6 +102,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
attn_type: str = AttentionType.DECODER,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
@@ -143,6 +144,12 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")

def forward(
self,
query: torch.Tensor,
@@ -152,7 +159,6 @@ def forward(
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -166,11 +172,6 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

12 changes: 6 additions & 6 deletions vllm/attention/backends/ipex_attn.py
@@ -115,6 +115,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -146,6 +147,11 @@ def __init__(
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")

def split_kv_cache(
self,
@@ -172,7 +178,6 @@ def forward(
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
@@ -189,11 +194,6 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
13 changes: 7 additions & 6 deletions vllm/attention/backends/pallas.py
@@ -100,6 +100,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -141,6 +142,12 @@ def __init__(
# megacore mode will be None.
self.megacore_mode = "batch"

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")

def forward(
self,
query: torch.Tensor,
@@ -150,7 +157,6 @@ def forward(
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@@ -168,11 +174,6 @@ def forward(
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
14 changes: 7 additions & 7 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -338,6 +338,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -397,6 +398,12 @@ def __init__(
self.attn_func = _sdpa_attention
logger.debug("Using naive attention in ROCmBackend")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
@@ -414,7 +421,6 @@ def forward(
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -432,12 +438,6 @@ def forward(
"""
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
4 changes: 3 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -390,6 +390,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -421,6 +422,7 @@ def __init__(
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
self.attn_type = attn_type

def forward(
self,
@@ -431,7 +433,6 @@ def forward(
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -448,6 +449,7 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
6 changes: 4 additions & 2 deletions vllm/attention/backends/xformers.py
@@ -379,6 +379,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -405,6 +406,8 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

self.attn_type = attn_type

def forward(
self,
query: torch.Tensor,
@@ -414,7 +417,6 @@ def forward(
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -468,7 +470,7 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""

attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER