
[V1] Implement Cascade Attention #11635

Merged · 23 commits · Jan 1, 2025
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -550,7 +550,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
238 changes: 225 additions & 13 deletions vllm/v1/attention/backends/flash_attn.py
@@ -2,11 +2,15 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import numpy as np
import torch
import triton
import triton.language as tl

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.utils import cdiv


class FlashAttentionBackend(AttentionBackend):
@@ -38,6 +42,10 @@ def get_kv_cache_shape(
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)


@dataclass
class FlashAttentionMetadata:
@@ -56,6 +64,15 @@ class FlashAttentionMetadata:
seq_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor

# For cascade attention.
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor]
cu_prefix_kv_lens: Optional[torch.Tensor]
cu_suffix_kv_lens: Optional[torch.Tensor]

# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
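
The three cu_* fields are cumulative-length tensors in the varlen format that flash_attn_varlen_func expects. As a rough illustration only (not part of this diff; the real construction happens outside this file, presumably in the V1 GPU model runner, and every name and number below is hypothetical), they could be built like this for a decode batch sharing a 768-token prefix:

# Rough illustration -- not part of this diff. All values are made up.
import torch

query_lens = [1, 1, 1, 1]           # decode batch: one new token per request
seq_lens = [900, 905, 910, 920]     # total KV length per request
common_prefix_len = 768             # must be a multiple of the block size

num_tokens = sum(query_lens)
# Prefix phase: all query tokens attend to the shared prefix as one sequence.
cu_prefix_query_lens = torch.tensor([0, num_tokens], dtype=torch.int32)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
# Suffix phase: each request attends only to its KV beyond the shared prefix.
cu_suffix = [0]
for seq_len in seq_lens:
    cu_suffix.append(cu_suffix[-1] + seq_len - common_prefix_len)
cu_suffix_kv_lens = torch.tensor(cu_suffix, dtype=torch.int32)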


@@ -169,21 +186,216 @@ def forward(
)

# Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade:
    # Regular attention (common case).
    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=attn_metadata.query_start_loc,
        max_seqlen_q=attn_metadata.max_query_len,
        cu_seqlens_k=attn_metadata.seq_start_loc,
        max_seqlen_k=attn_metadata.max_seq_len,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=attn_metadata.block_table,
        softcap=self.logits_soft_cap,
    )
    return output

# Cascade attention (rare case).
cascade_attention(
    output[:num_actual_tokens],
    query[:num_actual_tokens],
    key_cache,
    value_cache,
    cu_query_lens=attn_metadata.query_start_loc,
    max_query_len=attn_metadata.max_query_len,
    cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
    cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
    cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
    max_kv_len=attn_metadata.max_seq_len,
    softmax_scale=self.scale,
    alibi_slopes=self.alibi_slopes,
    sliding_window=self.sliding_window,
    logits_soft_cap=self.logits_soft_cap,
    block_table=attn_metadata.block_table,
    common_prefix_len=attn_metadata.common_prefix_len,
)
return output


def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
num_query_heads: int,
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
num_sms: int,
) -> bool:
# Too short common prefix. Probably not worth using cascade attention.
# NOTE(woosuk): This is the common case. We should return False as soon as
# possible to avoid any unnecessary computation.
if common_prefix_len < 256:
return False
# Cascade attention is currently not supported with these variants.
if use_alibi or use_sliding_window:
return False
# Too few queries. Probably not worth using cascade attention.
num_reqs = len(query_lens)
if num_reqs < 8:
return False

# Heuristics to decide whether using cascade attention is beneficial.
# 1. When FlashDecoding is not used for normal attention, cascade attention
# is likely to be faster since it saves memory bandwidth.
num_queries_per_kv = num_query_heads // num_kv_heads
use_flash_decoding = (num_queries_per_kv > 1 and np.all(query_lens == 1)
and not use_sliding_window)
if not use_flash_decoding:
# Use cascade attention.
return True

# 2. When FlashDecoding is used for normal attention, it is not clear
# whether cascade attention is beneficial, because FlashDecoding can
# launch more CTAs than cascade attention.
# We use a simple performance model to compare the two methods.
# NOTE(woosuk): The performance model is very rough and may not be
# accurate.
num_tokens = num_reqs
q_tile_size = 128
kv_tile_size = 128
num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
cascade_waves = cdiv(cascade_ctas, num_sms)
cascade_time = cascade_waves * num_prefix_tiles

flash_decoding_ctas = (num_reqs * num_kv_heads *
cdiv(num_queries_per_kv, q_tile_size))
flash_decoding_ctas *= num_prefix_tiles
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

# Use cascade attention if it is faster than FlashDecoding.
return cascade_time < flash_decoding_time
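
To make the rough model above concrete, here is a worked example with purely illustrative numbers (32 query heads, 8 KV heads, a 32-request decode batch, a 1024-token shared prefix, and 108 SMs, roughly an A100); none of these figures come from the PR itself:

# Worked example of the performance model above; numbers are illustrative.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

num_query_heads, num_kv_heads = 32, 8
num_reqs = num_tokens = 32            # pure decode: one token per request
common_prefix_len, num_sms = 1024, 108
q_tile_size = kv_tile_size = 128

num_queries_per_kv = num_query_heads // num_kv_heads      # 4
num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)  # 8

cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)  # 32
cascade_waves = cdiv(cascade_ctas, num_sms)                     # 1
cascade_time = cascade_waves * num_prefix_tiles                 # 8

flash_decoding_ctas = (num_reqs * num_kv_heads *
                       cdiv(num_queries_per_kv, q_tile_size))   # 256
flash_decoding_ctas *= num_prefix_tiles                         # 2048
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)        # 19

assert cascade_time < flash_decoding_time  # cascade wins in this setting

With these numbers cascade attention needs about 8 tile-waves over the shared prefix versus roughly 19 for FlashDecoding, so the heuristic returns True.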


def cascade_attention(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_query_lens: torch.Tensor,
max_query_len: int,
cu_prefix_query_lens: torch.Tensor,
cu_prefix_kv_lens: torch.Tensor,
cu_suffix_kv_lens: torch.Tensor,
max_kv_len: int,
softmax_scale: float,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Tuple[int, int],
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
) -> torch.Tensor:
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
# TODO: Support sliding window.
assert sliding_window == (-1, -1), (
"Cascade attention does not support sliding window.")

num_tokens = query.shape[0]
num_query_heads = query.shape[1]
head_size = query.shape[2]
block_size = key_cache.shape[-3]
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0

# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
cu_seqlens_k=cu_prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale,
causal=False,
window_size=sliding_window,
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
)

# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window,
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
)

# Merge prefix and suffix outputs.
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states[(num_tokens, num_query_heads)](
output,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
head_size,
triton.next_power_of_2(head_size),
)


@triton.jit
def merge_attn_states(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)

p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse

head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
mask=head_mask)
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
mask=head_mask)

p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
out = p_out * p_scale + s_out * s_scale
tl.store(output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
out,
mask=head_mask)
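
For reference, the kernel above combines the prefix and suffix partial outputs by weighting each with the softmax mass it contributed, which it recovers from the two log-sum-exp values. A plain PyTorch sketch of the same math (for understanding only; the PR runs the Triton kernel, and this helper name is made up):

# Reference-only sketch of the merge math; not part of the PR.
import torch

def merge_attn_states_ref(
    prefix_output: torch.Tensor,  # [num_tokens, num_heads, head_size]
    prefix_lse: torch.Tensor,     # [num_heads, num_tokens]
    suffix_output: torch.Tensor,  # [num_tokens, num_heads, head_size]
    suffix_lse: torch.Tensor,     # [num_heads, num_tokens]
) -> torch.Tensor:
    # Transpose LSEs to [num_tokens, num_heads, 1] so they broadcast over head_size.
    p_lse = prefix_lse.t().unsqueeze(-1)
    s_lse = suffix_lse.t().unsqueeze(-1)
    # Subtract the max before exponentiating, as the kernel does, for stability.
    max_lse = torch.maximum(p_lse, s_lse)
    p_weight = torch.exp(p_lse - max_lse)
    s_weight = torch.exp(s_lse - max_lse)
    denom = p_weight + s_weight
    return (prefix_output * (p_weight / denom) +
            suffix_output * (s_weight / denom))

This is the standard split-KV trick: softmax over the union of two key sets factors into the two partial softmaxes weighted by their normalizers, so the merged result is exact rather than an approximation.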
16 changes: 16 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
@@ -271,6 +271,22 @@ def free(self, request: Request) -> None:
if block.ref_cnt == 0:
self.free_block_queue.append(block)

def get_num_common_prefix_blocks(
self,
request: Request,
num_requests: int,
) -> int:
blocks = self.req_to_blocks[request.request_id]
num_common_blocks = 0
for block in blocks:
# FIXME(woosuk): For some reason, sometimes the ref_cnt is greater
# than the number of running requests. DEBUG this.
if block.ref_cnt >= num_requests:
num_common_blocks += 1
else:
break
return num_common_blocks
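
An illustrative sketch of how the block count returned above could be turned into a token-level prefix length and fed to the attention backend's heuristic; the wiring shown here is hypothetical (in the PR it happens outside this file) and the numbers are made up:

# Hypothetical wiring; not part of this diff.
import numpy as np
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

block_size = 16
num_common_prefix_blocks = 48        # e.g. the value returned by the method above
common_prefix_len = num_common_prefix_blocks * block_size   # 768 tokens

use_cascade = FlashAttentionBackend.use_cascade_attention(
    common_prefix_len=common_prefix_len,
    query_lens=np.ones(32, dtype=np.int32),  # 32-request decode batch
    num_query_heads=32,
    num_kv_heads=8,
    use_alibi=False,
    use_sliding_window=False,
    num_sms=108,
)
# With these numbers the heuristic chooses cascade attention (True).

Because common_prefix_len is derived by multiplying whole blocks by block_size, it automatically satisfies the common_prefix_len % block_size == 0 assertion in cascade_attention.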

def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool.

10 changes: 10 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -262,6 +262,14 @@ def schedule(self) -> "SchedulerOutput":
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) == len(self.running))

# Get the longest common prefix among all running requests. This can
# potentially be used for cascade attention.
num_common_prefix_blocks = 0
if self.running:
    any_request = self.running[0]
    num_common_prefix_blocks = (
        self.kv_cache_manager.get_num_common_prefix_blocks(
            any_request, len(self.running)))

# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
@@ -287,6 +295,7 @@ def schedule(self) -> "SchedulerOutput":
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids=preempted_req_ids,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
@@ -594,6 +603,7 @@ class SchedulerOutput:
num_scheduled_tokens: Dict[str, int]
total_num_scheduled_tokens: int
scheduled_encoder_inputs: Dict[str, List[int]]
num_common_prefix_blocks: int

preempted_req_ids: Set[str]
finished_req_ids: Set[str]