Skip to content

Commit

Permalink
[Kernel] Use out arg in flash_attn_varlen_func (#10811)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon authored Dec 2, 2024
1 parent b795477 commit 073a4bd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG fdf6d72b48aea41f4ae6a89139a453dae554abc8
GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
20 changes: 17 additions & 3 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
Expand All @@ -81,6 +82,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
Expand Down Expand Up @@ -116,17 +118,22 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

q = query.unsqueeze(1)
out = torch.empty_like(q) if use_out else None
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
q=q,
k_cache=key_cache,
v_cache=value_cache,
out=out,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
).squeeze(1)
)
output = output if not use_out else out
output = output.squeeze(1)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
Expand All @@ -141,7 +148,10 @@ def test_flash_attn_with_paged_kv(
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
Expand All @@ -151,6 +161,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
Expand Down Expand Up @@ -197,10 +208,12 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

out = torch.empty_like(query) if use_out else None
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
Expand All @@ -211,6 +224,7 @@ def test_varlen_with_paged_kv(
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
output = output if not use_out else out

ref_output = ref_paged_attn(
query=query,
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,12 @@ def unified_v1_flash_attention(
v_scale,
)

attn_output = flash_attn_varlen_func(
# Compute attention and update output up to `num_actual_tokens`.
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc,
Expand All @@ -220,8 +222,6 @@ def unified_v1_flash_attention(
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
# TODO(woosuk): Remove this unnecessary copy.
output[:num_actual_tokens].copy_(attn_output)


def unified_v1_flash_attention_fake(
Expand Down

0 comments on commit 073a4bd

Please sign in to comment.