[Kernel] Move attn_type to Attention.__init__() #11690

Merged
5 commits merged on Jan 6, 2025
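Summary of the API change: `attn_type` moves from a per-call keyword argument of `Attention.forward()` to a constructor argument of `Attention.__init__()`, so the attention role (decoder self-attention, encoder attention, encoder/decoder cross-attention) is fixed when the layer is built. The sketch below shows the before/after call shape under stated assumptions: placeholder tensors, and the `set_current_vllm_config(...)` / `set_forward_context(...)` contexts the tests wrap around construction and invocation are omitted for brevity.

```python
import torch

from vllm.attention import Attention, AttentionMetadata, AttentionType


def build_encoder_attention(num_heads: int, head_size: int) -> Attention:
    # After this PR: the attention role is chosen once, at construction time.
    scale = float(1.0 / (head_size**0.5))
    return Attention(
        num_heads,
        head_size,
        scale=scale,
        attn_type=AttentionType.ENCODER,
    )


def run_attention(attn: Attention, query: torch.Tensor, key: torch.Tensor,
                  value: torch.Tensor, kv_cache: torch.Tensor,
                  attn_metadata: AttentionMetadata) -> torch.Tensor:
    # Before this PR the role had to be passed on every call:
    #     attn.forward(query, key, value, kv_cache, attn_metadata,
    #                  attn_type=AttentionType.ENCODER)
    # Now the value captured in __init__ is used instead.
    return attn.forward(query, key, value, kv_cache, attn_metadata)
```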
100 changes: 49 additions & 51 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -13,8 +13,7 @@
import torch

from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len: int
max_enc_seq_len: int
num_blocks: int
attn_type: AttentionType


class TestResources(NamedTuple):
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
'''

scale: float
attn_backend: AttentionBackend
attn: Attention
kv_cache: torch.Tensor

@@ -129,16 +128,17 @@ class that Attention will automatically select when it is constructed.
'''

scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention(
test_pt.num_heads,
test_pt.head_size,
scale=scale,
prefix=f"{test_pt.attn_type}",
attn_type=test_pt.attn_type,
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(
scale, attn_backend, attn,
scale, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

# Construct KV cache
@@ -148,7 +148,7 @@ class that Attention will automatically select when it is constructed.
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache)
return TestResources(scale, attn, kv_cache)


def _encoder_attn_setup(
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
_,
max_q_seq_len,
_,
_,
) = test_pt

scale = test_rsrcs.scale
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
max_q_seq_len,
_,
_,
_,
) = test_pt

scale = test_rsrcs.scale
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len,
max_encoder_seq_len,
_,
_,
) = test_pt

scale = test_rsrcs.scale
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
& attn_metadata
'''
assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config):
@@ -635,14 +637,11 @@
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
return attn.forward(
reshaped_query, packed_qkv.key, packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device), attn_metadata)


def _run_decoder_self_attention_test(
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
@@ -690,12 +688,8 @@
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
kv_cache, attn_metadata)


def _run_encoder_decoder_cross_attention_test(
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
'''
assert decoder_test_params.packed_qkvo.packed_qkv is not None

attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None:
@@ -762,12 +755,8 @@
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, key, value, kv_cache,
attn_metadata)


@pytest.fixture(autouse=True)
@@ -839,7 +828,7 @@ def test_encoder_only(
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
max_enc_seq_len, 4096, AttentionType.ENCODER)

# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
@@ -855,7 +844,7 @@
# Shared prefill metadata structure

prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
None,
decoder_test_params=None,
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.ENCODER)
enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096,
AttentionType.ENCODER_DECODER)
dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.DECODER)

# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
enc_test_rsrcs = _make_test_resources(enc_test_pt)
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
dec_test_rsrcs = _make_test_resources(dec_test_pt)

# Construct encoder attention test params (only used
# during prefill)

enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)

# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
@@ -987,7 +985,7 @@
prephase_dec_test_params,
decphase_dec_test_params,
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)

# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
@@ -1000,14 +998,14 @@
dec_qkv,
enc_test_params,
prephase_dec_test_params,
test_pt,
test_rsrcs,
enc_dec_test_pt,
enc_dec_test_rsrcs,
block_base_addr=cross_block_base_addr)

# Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params,
@@ -1017,10 +1015,10 @@

# PREFILL: encoder attention

enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_test_pt,
vllm_config=vllm_config)

# - Is encoder attention result correct?
@@ -1030,10 +1028,10 @@
# PREFILL: decoder self-attention test

prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)

# - Is prefill decoder self-attention correct?
@@ -1044,11 +1042,11 @@
# PREFILL: encoder/decoder cross-attention test

prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)

# - Is prefill encoder/decoder cross-attention correct?
@@ -1059,7 +1057,7 @@
# DECODE: build decode-phase attention metadata

decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
False,
dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params,
@@ -1070,10 +1068,10 @@
# DECODE: decoder self-attention test

decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)

# - Is decode-phase decoder self-attention correct?
@@ -1084,11 +1082,11 @@
# DECODE: encoder/decoder cross-attention test

decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)

# - Is decode-phase encoder/decoder cross-attention correct?
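With `attn_type` now a `TestPoint` field, `test_e2e_enc_dec_attn` above builds a separate `TestPoint` and `TestResources` per attention role instead of sharing one. A condensed sketch of that pattern, assuming `TestPoint` and `_make_test_resources` from this test module and `vllm.config` as the import path for `VllmConfig` / `set_current_vllm_config`:

```python
from vllm.attention import AttentionType
from vllm.config import VllmConfig, set_current_vllm_config

# TestPoint and _make_test_resources are defined in
# tests/kernels/test_encoder_decoder_attn.py.


def make_per_role_resources(num_heads, head_size, backend_name, batch_size,
                            block_size, max_dec_seq_len, max_enc_seq_len):
    """Build one (TestPoint, TestResources) pair for each attention role."""
    points = {
        attn_type: TestPoint(num_heads, head_size, backend_name, batch_size,
                             block_size, max_dec_seq_len, max_enc_seq_len,
                             4096, attn_type)
        for attn_type in (AttentionType.ENCODER, AttentionType.DECODER,
                          AttentionType.ENCODER_DECODER)
    }
    # Attention layers must be constructed under an active vllm config,
    # matching what the test itself does.
    with set_current_vllm_config(VllmConfig()):
        return {attn_type: (pt, _make_test_resources(pt))
                for attn_type, pt in points.items()}
```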
12 changes: 7 additions & 5 deletions tests/kernels/utils.py
@@ -13,6 +13,7 @@

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(


def make_test_metadata(
attn_backend: AttentionBackend,
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters],
@@ -815,7 +816,7 @@

Arguments:

* attn_backend: Backend for sourcing attention kernels
* attn_backend_name: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
@@ -882,6 +883,8 @@
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap

attn_backend_obj = make_backend(attn_backend.name)

if is_prompt:
# Prefill-phase scenario

@@ -902,8 +905,7 @@
context_lens,
encoder_seq_lens,
device=device)

return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
@@ -952,7 +954,7 @@
encoder_seq_lens,
device=device)

return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
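The caller-side consequence of this change: `make_test_metadata()` now receives the `_Backend` enum member straight from the test fixture and instantiates the backend itself via `make_backend(attn_backend.name)`, so `TestResources` no longer carries a backend object. A small sketch of that conversion, assuming `make_backend` is the helper defined in this module:

```python
from tests.kernels.utils import make_backend
from vllm.platforms.interface import _Backend


def resolve_backend(attn_backend: _Backend):
    # Mirrors what make_test_metadata() now does internally: accept the enum
    # member and build the concrete AttentionBackend object from its name.
    return make_backend(attn_backend.name)


# Usage inside a backend-parameterized test (illustrative):
#     backend_obj = resolve_backend(_Backend.XFORMERS)
#     metadata = backend_obj.make_metadata(...)  # kwargs as in make_test_metadata
```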
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -233,6 +233,7 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
raise NotImplementedError

@@ -246,7 +247,6 @@ def forward(
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
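
For backend implementations, the interface change above means `attn_type` is validated and stored once in `AttentionImpl.__init__()` rather than checked on every `forward()` call. An illustrative-only subclass sketch follows; the constructor parameters ahead of `kv_cache_dtype` are not visible in the hunk above and are assumptions about the interface at the time, so verify against the file before reusing:

```python
from typing import Any, Dict, List, Optional

import torch

from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionImpl


class ToyAttentionImpl(AttentionImpl):
    """Illustrative only; not a real vLLM backend."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        # The attention role is fixed at construction, so unsupported roles
        # can be rejected up front instead of on every call.
        if attn_type not in (AttentionType.DECODER, AttentionType.ENCODER):
            raise NotImplementedError(
                f"Attention type {attn_type!r} is not supported here")
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # attn_type is no longer a forward() parameter; behaviour is driven by
        # the value captured in __init__. A real backend would compute
        # attention here; this toy just echoes the query.
        return query
```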