[V1] Implement Cascade Attention #11635

Merged · 23 commits · Jan 1, 2025
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -550,7 +550,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
238 changes: 225 additions & 13 deletions vllm/v1/attention/backends/flash_attn.py
@@ -2,10 +2,14 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import numpy as np
import torch
import triton
import triton.language as tl

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.utils import cdiv
from vllm.vllm_flash_attn import flash_attn_varlen_func


@@ -38,6 +42,10 @@ def get_kv_cache_shape(
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return use_cascade_attention(*args, **kwargs)


@dataclass
class FlashAttentionMetadata:
@@ -56,6 +64,15 @@ class FlashAttentionMetadata:
    seq_start_loc: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    cu_prefix_kv_lens: Optional[torch.Tensor]
    cu_suffix_kv_lens: Optional[torch.Tensor]

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.
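For orientation, here is a minimal sketch of how these cascade fields can be laid out (this helper and its `seq_lens` argument are illustrative assumptions, not part of this diff): the shared-prefix pass treats all query tokens as a single sequence attending to the common prefix, so its cumulative lengths collapse to two entries, while the suffix pass keeps per-request KV lengths with the common prefix subtracted.

```python
import torch

def build_cascade_metadata_tensors(
    num_actual_tokens: int,   # total query tokens in the batch
    common_prefix_len: int,   # length of the shared prefix, in tokens
    seq_lens: torch.Tensor,   # [num_reqs] per-request KV lengths
    device: torch.device,
):
    """Hypothetical helper showing the layout of the cascade metadata."""
    # One "virtual" sequence: every query token attends to the shared prefix.
    cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                        dtype=torch.int32, device=device)
    cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                     dtype=torch.int32, device=device)
    # Per-request KV lengths with the common prefix removed, as a cumsum.
    suffix_kv_lens = (seq_lens - common_prefix_len).to(torch.int32)
    cu_suffix_kv_lens = torch.zeros(seq_lens.numel() + 1,
                                    dtype=torch.int32, device=device)
    cu_suffix_kv_lens[1:] = torch.cumsum(suffix_kv_lens, dim=0)
    return cu_prefix_query_lens, cu_prefix_kv_lens, cu_suffix_kv_lens
```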


@@ -169,21 +186,216 @@ def forward(
)

        # Compute attention and update output up to `num_actual_tokens`.
        if not attn_metadata.use_cascade:
            # Regular attention (common case).
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=attn_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                cu_seqlens_k=attn_metadata.seq_start_loc,
                max_seqlen_k=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
            )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
            cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
        )
        return output


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    # Too short common prefix. Probably not worth using cascade attention.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window:
        return False
    # Too few queries. Probably not worth using cascade attention.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    use_flash_decoding = (num_queries_per_kv > 1 and np.all(query_lens == 1)
                          and not use_sliding_window)
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    num_tokens = num_reqs
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time
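To make the performance model concrete, here is a rough worked example with illustrative numbers (not taken from the PR): a decode-only batch of 64 requests using grouped-query attention, a 4096-token shared prefix, and a GPU with 132 SMs.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division, mirroring vllm.utils.cdiv

# Illustrative decode-only batch: each request generates one token.
common_prefix_len = 4096
num_reqs = num_tokens = 64
num_query_heads, num_kv_heads = 32, 4
num_queries_per_kv = num_query_heads // num_kv_heads  # 8 -> FlashDecoding path
num_sms = 132
q_tile_size = kv_tile_size = 128

num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)        # 32

cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)  # 32
cascade_time = cdiv(cascade_ctas, num_sms) * num_prefix_tiles   # 1 * 32 = 32

flash_decoding_ctas = (num_reqs * num_kv_heads *
                       cdiv(num_queries_per_kv, q_tile_size))   # 256
flash_decoding_ctas *= num_prefix_tiles                         # 8192
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)        # 63

print(cascade_time < flash_decoding_time)  # True -> use cascade attention
```

In this regime the model charges cascade attention for the shared prefix roughly once per query-tile wave rather than once per request, which is where the projected win comes from.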


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    cu_prefix_kv_lens: torch.Tensor,
    cu_suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: Tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
) -> torch.Tensor:
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    num_query_heads = query.shape[1]
    head_size = query.shape[2]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        cu_seqlens_k=cu_prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
    )

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
    )

    # Merge prefix and suffix outputs.
    # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
    merge_attn_states[(num_tokens, num_query_heads)](
        output,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        triton.next_power_of_2(head_size),
    )


@triton.jit
def merge_attn_states(
    output,         # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse,     # [NUM_HEADS, NUM_TOKENS]
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse,     # [NUM_HEADS, NUM_TOKENS]
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
    s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)
    s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
                    head_idx * HEAD_SIZE + head_arange,
                    mask=head_mask)

    p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
    s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
    out = p_out * p_scale + s_out * s_scale
    tl.store(output + token_idx * num_heads * HEAD_SIZE +
             head_idx * HEAD_SIZE + head_arange,
             out,
             mask=head_mask)
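For intuition, here is a numerically equivalent PyTorch reference of the merge that the kernel performs: each partial output is weighted by the softmax of its log-sum-exp, i.e. `out = (exp(p_lse) * p_out + exp(s_lse) * s_out) / (exp(p_lse) + exp(s_lse))`. This is only a readability sketch; the Triton kernel above is what the diff actually runs.

```python
import torch

def merge_attn_states_ref(prefix_output: torch.Tensor,  # [T, H, D]
                          prefix_lse: torch.Tensor,     # [H, T]
                          suffix_output: torch.Tensor,  # [T, H, D]
                          suffix_lse: torch.Tensor,     # [H, T]
                          ) -> torch.Tensor:
    """Reference merge of two partial attention outputs via their LSEs."""
    # Reshape [H, T] -> [T, H, 1] so the weights broadcast over head_size.
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1)
    # Subtract the max for numerical stability, as the kernel does.
    max_lse = torch.maximum(p_lse, s_lse)
    p_weight = torch.exp(p_lse - max_lse)
    s_weight = torch.exp(s_lse - max_lse)
    denom = p_weight + s_weight
    return (prefix_output * (p_weight / denom) +
            suffix_output * (s_weight / denom))
```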
50 changes: 49 additions & 1 deletion vllm/v1/core/kv_cache_manager.py
@@ -8,7 +8,7 @@
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)

@@ -278,6 +278,54 @@ def free(self, request: Request) -> None:
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def get_num_common_prefix_blocks(
        self,
        request: Request,
        num_running_requests: int,
    ) -> int:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state.

        The function determines this by selecting any request and iterating
        through its blocks. A block is considered a common prefix block if its
        `ref_cnt` equals the total number of requests in the RUNNING state.

        NOTE(woosuk): The number of requests in the RUNNING state is **greater
        than or equal to** the number of requests scheduled in the current
        step. This is because the RUNNING state only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must be in the RUNNING state, the inverse
        is not necessarily true. There may be RUNNING requests that are not
        scheduled in the current step.

        This can result in an edge case where the number of common prefix
        blocks is 0, even though all scheduled requests share a common prefix.
        This occurs because there may be unscheduled RUNNING requests that do
        not share the common prefix. Currently, this case cannot be easily
        detected, so the function returns 0 in such cases.

        Args:
            request: Any request in the RUNNING state, used to identify the
                common prefix blocks.
            num_running_requests: The total number of requests in the RUNNING
                state. This can be different from the number of scheduled
                requests in the current step.

        Returns:
            int: The number of common prefix blocks.
        """
        assert request.status == RequestStatus.RUNNING
        blocks = self.req_to_blocks[request.request_id]
        num_common_blocks = 0
        for block in blocks:
            if block.ref_cnt == num_running_requests:
                num_common_blocks += 1
            else:
                break
        return num_common_blocks
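As a toy illustration of the counting rule above (the block objects here are simplified stand-ins, not the real `KVCacheBlock`): with three RUNNING requests that all share the first two blocks, only those blocks have `ref_cnt == 3`, and the scan stops at the first block referenced by fewer requests.

```python
from dataclasses import dataclass

@dataclass
class FakeBlock:
    ref_cnt: int  # stand-in for KVCacheBlock.ref_cnt, for illustration only

# Three RUNNING requests share the first two blocks; the third block is held
# by only two of them, so it (and everything after it) is not a common prefix.
blocks = [FakeBlock(3), FakeBlock(3), FakeBlock(2), FakeBlock(1)]
num_running_requests = 3

num_common_blocks = 0
for block in blocks:
    if block.ref_cnt == num_running_requests:
        num_common_blocks += 1
    else:
        break
print(num_common_blocks)  # 2
```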

    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool.

10 changes: 10 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -262,6 +262,14 @@ def schedule(self) -> "SchedulerOutput":
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) == len(self.running))

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
                self.kv_cache_manager.get_num_common_prefix_blocks(
                    any_request, len(self.running)))

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(req,
@@ -287,6 +295,7 @@ def schedule(self) -> "SchedulerOutput":
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            preempted_req_ids=preempted_req_ids,
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
@@ -594,6 +603,7 @@ class SchedulerOutput:
    num_scheduled_tokens: Dict[str, int]
    total_num_scheduled_tokens: int
    scheduled_encoder_inputs: Dict[str, List[int]]
    num_common_prefix_blocks: int

    preempted_req_ids: Set[str]
    finished_req_ids: Set[str]