
Commit

rebase all codes
andy-yang-1 committed Sep 25, 2024
1 parent 8f527e2 commit 57c998b
Showing 7 changed files with 964 additions and 2 deletions.
182 changes: 182 additions & 0 deletions python/sglang/srt/layers/attention_backend.py
@@ -478,3 +478,185 @@ def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
            layer.logit_cap,
        )
        return o


class DoubleSparseAttnBackend(AttentionBackend):
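    """Attention backend implementing double sparse attention.

    The extend (prefill) path runs the dense Triton kernel and caches
    reduced-dimension key "labels" (pre-selected channels) alongside K/V.
    The decode path uses the dense Triton kernel for short sequences and a
    sparse Triton kernel, driven by query/key labels and a heavy-token
    budget, once sequences are long enough.
    """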
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid initializing the CUDA context at import time
        from sglang.srt.layers.triton_attention.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.extend_attention import (
            extend_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.sparse_decode_attention import (
            decode_sparse_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.decode_sparse_attention_fwd = decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

    def init_forward_metadata(
        self, batch: ScheduleBatch, input_metadata: InputMetadata
    ):
        """Initialize auxiliary variables for the double sparse attention backend."""

        if input_metadata.forward_mode.is_decode():
            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)

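            # Per-head buffer of intermediate attention logits used by the dense
            # Triton decode kernel (one slot per token in the batch)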
            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(input_metadata.seq_lens).item()
            max_extend_len = None
            # NOTE: Align sequence order with req_to_token order
            ds_req_to_token = input_metadata.req_to_token_pool.req_to_token[
                input_metadata.req_pool_indices
            ]
        else:
            start_loc = attn_logits = max_seq_len = None
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
            ds_req_to_token = None

        self.forward_metadata = (
            start_loc,
            attn_logits,
            max_seq_len,
            max_extend_len,
            ds_req_to_token,
        )

    def init_cuda_graph_state(self, max_bs: int):
        # TODO(Andy): Support CUDA graph for double sparse attention
        raise ValueError(
            "Double sparse attention does not support CUDA graph for now. "
            "Please run with --disable-cuda-graph."
        )
        # The buffers below are kept for future CUDA graph support and are
        # currently unreachable.
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
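        """Run dense extend (prefill) attention and cache key labels for sparse decode."""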
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

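        # Gather only the pre-selected "label" channels of K for this layer; the
        # reduced-dimension key labels are cached so the sparse decode kernel can
        # score tokens cheaply later.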
        k_label = torch.gather(
            k,
            2,
            input_metadata.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v, k_label
        )

        (
            start_loc,
            attn_logits,
            max_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata
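        # Dense Triton extend kernel over the newly appended tokens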
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
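        """Decode attention: dense kernel for short sequences, sparse kernel otherwise."""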
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        (
            start_loc,
            attn_logits,
            max_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata

        k_label = torch.gather(
            k,
            2,
            input_metadata.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v, k_label
        )

        # NOTE(Andy): Sparse decoding should not be used when
        # max_seq_len < heavy_token_num; sparse_decode_thresold additionally sets
        # a minimum sequence length for the sparse path.
        if (
            max_seq_len < input_metadata.heavy_token_num
            or max_seq_len < input_metadata.sparse_decode_thresold
        ):
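            # Short sequences: fall back to the dense decode kernel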
            self.decode_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
                input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                input_metadata.req_to_token_pool.req_to_token,
                input_metadata.req_pool_indices,
                start_loc,
                input_metadata.seq_lens,
                attn_logits,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
            )
        else:
            # TODO(Andy): indexing with torch.gather or torch.index_select
            # or a customized kernel
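            # Gather the same label channels from Q so the query labels match the
            # cached key labels in the reduced channel space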
            q_label = torch.gather(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                2,
                input_metadata.sorted_channels[layer.layer_id]
                .unsqueeze(0)
                .expand(q.shape[0], -1, -1),
            )
            self.decode_sparse_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
                input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                q_label,
                input_metadata.token_to_kv_pool.get_label_buffer(layer.layer_id),
                ds_req_to_token,
                input_metadata.seq_lens,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
                input_metadata.heavy_token_num,
            )

        return o
