[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by: Chen Zhang <[email protected]>
heheda12345 authored Jan 6, 2025
1 parent 32c9eff commit e20c92b
Showing 18 changed files with 159 additions and 201 deletions.
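In short, the attention type is now bound when an Attention layer is constructed instead of being passed to every forward() call. Below is a minimal before/after sketch of the calling convention, derived only from the signatures visible in this diff; num_heads, head_size, scale and the tensor arguments (q, k, v, kv_cache, attn_metadata) are placeholders, not runnable setup.

    from vllm.attention import Attention, AttentionType

    # Before this commit: the layer was type-agnostic and each call chose
    # the attention type.
    attn = Attention(num_heads, head_size, scale=scale)
    out = attn.forward(q, k, v, kv_cache, attn_metadata,
                       attn_type=AttentionType.ENCODER)

    # After this commit: the type is fixed once in the constructor and
    # forward() no longer accepts attn_type.
    attn = Attention(num_heads, head_size, scale=scale,
                     attn_type=AttentionType.ENCODER)
    out = attn.forward(q, k, v, kv_cache, attn_metadata)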
100 changes: 49 additions & 51 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -13,8 +13,7 @@
import torch

from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len: int
max_enc_seq_len: int
num_blocks: int
attn_type: AttentionType


class TestResources(NamedTuple):
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
'''

scale: float
attn_backend: AttentionBackend
attn: Attention
kv_cache: torch.Tensor

@@ -129,16 +128,17 @@ class that Attention will automatically select when it is constructed.
'''

scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention(
test_pt.num_heads,
test_pt.head_size,
scale=scale,
prefix=f"{test_pt.attn_type}",
attn_type=test_pt.attn_type,
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(
scale, attn_backend, attn,
scale, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

# Construct KV cache
@@ -148,7 +148,7 @@ class that Attention will automatically select when it is constructed.
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache)
return TestResources(scale, attn, kv_cache)


def _encoder_attn_setup(
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
_,
max_q_seq_len,
_,
_,
) = test_pt

scale = test_rsrcs.scale
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
max_q_seq_len,
_,
_,
_,
) = test_pt

scale = test_rsrcs.scale
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len,
max_encoder_seq_len,
_,
_,
) = test_pt

scale = test_rsrcs.scale
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
& attn_metadata
'''
assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config):
@@ -635,14 +637,11 @@
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
return attn.forward(
reshaped_query, packed_qkv.key, packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device), attn_metadata)


def _run_decoder_self_attention_test(
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
@@ -690,12 +688,8 @@
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
kv_cache, attn_metadata)


def _run_encoder_decoder_cross_attention_test(
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
'''
assert decoder_test_params.packed_qkvo.packed_qkv is not None

attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None:
@@ -762,12 +755,8 @@
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, key, value, kv_cache,
attn_metadata)


@pytest.fixture(autouse=True)
@@ -839,7 +828,7 @@ def test_encoder_only(
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
max_enc_seq_len, 4096, AttentionType.ENCODER)

# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
@@ -855,7 +844,7 @@
# Shared prefill metadata structure

prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
None,
decoder_test_params=None,
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.ENCODER)
enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096,
AttentionType.ENCODER_DECODER)
dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.DECODER)

# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
enc_test_rsrcs = _make_test_resources(enc_test_pt)
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
dec_test_rsrcs = _make_test_resources(dec_test_pt)

# Construct encoder attention test params (only used
# during prefill)

enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)

# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
@@ -987,7 +985,7 @@
prephase_dec_test_params,
decphase_dec_test_params,
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)

# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
@@ -1000,14 +998,14 @@
dec_qkv,
enc_test_params,
prephase_dec_test_params,
test_pt,
test_rsrcs,
enc_dec_test_pt,
enc_dec_test_rsrcs,
block_base_addr=cross_block_base_addr)

# Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params,
Expand All @@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(

# PREFILL: encoder attention

enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_test_pt,
vllm_config=vllm_config)

# - Is encoder attention result correct?
@@ -1030,10 +1028,10 @@
# PREFILL: decoder self-attention test

prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)

# - Is prefill decoder self-attention correct?
@@ -1044,11 +1042,11 @@
# PREFILL: encoder/decoder cross-attention test

prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)

# - Is prefill encoder/decoder cross-attention correct?
@@ -1059,7 +1057,7 @@
# DECODE: build decode-phase attention metadata

decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
False,
dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params,
@@ -1070,10 +1068,10 @@
# DECODE: decoder self-attention test

decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)

# - Is decode-phase decoder self-attention correct?
@@ -1084,11 +1082,11 @@
# DECODE: encoder/decoder cross-attention test

decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)

# - Is decode-phase encoder/decoder cross-attention correct?
12 changes: 7 additions & 5 deletions tests/kernels/utils.py
@@ -13,6 +13,7 @@

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(


def make_test_metadata(
attn_backend: AttentionBackend,
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters],
@@ -815,7 +816,7 @@
Arguments:
* attn_backend: Backend for sourcing attention kernels
* attn_backend_name: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
@@ -882,6 +883,8 @@ def make_test_metadata(
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap

attn_backend_obj = make_backend(attn_backend.name)

if is_prompt:
# Prefill-phase scenario

@@ -902,8 +905,7 @@
context_lens,
encoder_seq_lens,
device=device)

return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
@@ -952,7 +954,7 @@ def make_test_metadata(
encoder_seq_lens,
device=device)

return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
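A side effect visible in the helper above: make_test_metadata now accepts a _Backend enum member and builds the backend object itself via make_backend(attn_backend.name), so TestResources no longer carries a pre-constructed AttentionBackend. A hedged caller-side sketch, with the remaining arguments elided:

    # Before this commit: the caller passed the backend object stored in
    # TestResources.
    meta = make_test_metadata(
        test_rsrcs.attn_backend,
        True,  # is_prompt
        seq_lens,
        decoder_test_params=None,
        # ... remaining keyword arguments unchanged ...
    )

    # After this commit: the caller passes the _Backend enum member; the
    # helper constructs the backend internally via make_backend().
    meta = make_test_metadata(
        attn_backend,
        True,
        seq_lens,
        decoder_test_params=None,
        # ... remaining keyword arguments unchanged ...
    )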
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -233,6 +233,7 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
raise NotImplementedError

@@ -246,7 +247,6 @@ def forward(
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
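For backend implementations, this is the contract change: attn_type is received once in __init__ (defaulting to AttentionType.DECODER) and is no longer a forward() parameter. Below is a rough sketch of how a concrete implementation might adapt, assuming the hunk above belongs to the AttentionImpl base class in this file; the subclass name and the simplified constructor arguments are illustrative only.

    from vllm.attention import AttentionType
    from vllm.attention.backends.abstract import AttentionImpl

    class MyAttentionImpl(AttentionImpl):  # hypothetical subclass for illustration
        def __init__(self, num_heads, head_size, scale,
                     attn_type: str = AttentionType.DECODER, **kwargs) -> None:
            self.num_heads = num_heads
            self.head_size = head_size
            self.scale = scale
            # The attention type is now bound here, once per layer.
            self.attn_type = attn_type

        def forward(self, query, key, value, kv_cache, attn_metadata,
                    k_scale: float = 1.0, v_scale: float = 1.0, output=None):
            # Dispatch on the stored type instead of a per-call argument.
            if self.attn_type == AttentionType.DECODER:
                ...  # decoder self-attention path
            elif self.attn_type == AttentionType.ENCODER_DECODER:
                ...  # encoder/decoder cross-attention path
            else:
                ...  # encoder-only attention path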
(Diffs for the remaining 15 changed files are not shown here.)
