support 2 kernels for mixed_chunk_prefill
lucky9-cyou committed Dec 22, 2024
1 parent 21e9e63 commit 9e97ab9
Showing 3 changed files with 181 additions and 4 deletions.
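For context: this commit lets the FlashInfer attention backend run a "mixed" batch, chunked-prefill (extend) requests followed by decode requests, with two kernels in one forward pass: a prefill kernel over the extend slice and a decode kernel over the decode slice, with the two outputs concatenated. The sketch below is illustrative only (made-up lengths, not part of the commit) and is a simplified version of the partition logic in init_forward_metadata further down, which assumes extend requests come first and that a request whose extend length is 1 marks the start of the decode part.

import torch

# Hypothetical mixed batch: 2 chunked-prefill requests followed by 3 decode requests.
extend_seq_lens = torch.tensor([512, 512, 1, 1, 1])
batch_size = len(extend_seq_lens)

decode_idx = int((extend_seq_lens == 1).nonzero()[0])   # first decode request -> 2
running_bs = batch_size - decode_idx                    # decode requests      -> 3
extend_bs = batch_size - running_bs                     # prefill requests     -> 2
extend_tokens = int(extend_seq_lens[:extend_bs].sum())  # tokens sent to the prefill kernel -> 1024

# q/k/v[:extend_tokens] go through the prefill path, q/k/v[extend_tokens:] through
# the decode path, and forward_mixed concatenates the two outputs along dim 0.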
12 changes: 12 additions & 0 deletions python/sglang/srt/layers/attention/__init__.py
@@ -83,3 +83,15 @@ def forward_extend(
    ):
        """Run a forward for extend."""
        raise NotImplementedError()

    def forward_mixed(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        """Run a forward for mixed."""
        raise NotImplementedError()
170 changes: 166 additions & 4 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -137,6 +137,58 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                encoder_lens=forward_batch.encoder_lens,
            )
            self.forward_metadata = (self.decode_wrappers,)
        elif forward_batch.forward_mode.is_mixed():
            decode_idx = 0
            for i in range(0, len(forward_batch.extend_seq_lens_cpu)):
                if forward_batch.extend_seq_lens[i] == 1:
                    decode_idx = i
                    break
            # decode_idx = max(1, decode_idx)

            running_bs = forward_batch.batch_size - decode_idx

            forward_batch.running_bs = running_bs

            extend_bs = forward_batch.batch_size - running_bs

            if forward_batch.extend_num_tokens - running_bs >= 4096 and self.num_wrappers == 1:
                use_ragged = True
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu[:extend_bs])
            else:
                use_ragged = False
                extend_no_prefix = False

            req_pool_indices_extend = forward_batch.req_pool_indices[:extend_bs]
            seq_lens_extend = forward_batch.seq_lens[:extend_bs]
            seq_lens_sum_extend = forward_batch.seq_lens[:extend_bs].sum().item()
            prefix_lens_extend = forward_batch.extend_prefix_lens[:extend_bs]
            encoder_lens_extend = forward_batch.encoder_lens[:extend_bs] if forward_batch.encoder_lens is not None else None

            self.indices_updater_prefill.update(
                req_pool_indices_extend,
                seq_lens_extend,
                seq_lens_sum_extend,
                prefix_lens_extend,
                use_ragged=use_ragged,
                encoder_lens=encoder_lens_extend,
            )

            self.indices_updater_decode.decode_indices = extend_bs

            seq_lens_sum_decode = forward_batch.seq_lens[extend_bs:].sum().item()
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices[extend_bs:],
                forward_batch.seq_lens[extend_bs:],
                seq_lens_sum_decode,
                decode_wrappers=None,
                encoder_lens=forward_batch.encoder_lens[extend_bs:] if forward_batch.encoder_lens is not None else None,
            )

            self.forward_metadata = (use_ragged, extend_no_prefix, self.decode_wrappers)

        else:
            prefix_lens = forward_batch.extend_prefix_lens

@@ -220,6 +272,102 @@ def init_forward_metadata_replay_cuda_graph(
    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_mixed(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        prefill_wrapper_paged = self.prefill_wrappers_paged[
            self._get_wrapper_idx(layer)
        ]

        running_bs = forward_batch.running_bs

        use_ragged, extend_no_prefix, _ = self.forward_metadata

        extend_tokens = forward_batch.extend_num_tokens - running_bs
        cache_loc_extend = (
            forward_batch.out_cache_loc[:extend_tokens]
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc[:extend_tokens]
        )

        k_extend = k[:extend_tokens]
        v_extend = v[:extend_tokens]
        q_extend = q[:extend_tokens]

        if not use_ragged:
            if k_extend is not None:
                assert v_extend is not None
                if save_kv_cache:
                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc_extend, k_extend, v_extend)

            o = prefill_wrapper_paged.forward(
                q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=not layer.is_cross_attention,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_extend.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v_extend.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q_extend.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc_extend, k_extend, v_extend)

        o = o.view(-1, layer.tp_q_head_num * layer.head_dim)

        decode_wrapper = self.forward_metadata[2][self._get_wrapper_idx(layer)]
        cache_loc_decode = (
            forward_batch.out_cache_loc[extend_tokens:]
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc[extend_tokens:]
        )
        k_decode = k[extend_tokens:]
        v_decode = v[extend_tokens:]
        q_decode = q[extend_tokens:]

        if k_decode is not None:
            assert v_decode is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc_decode, k_decode, v_decode)

        o_decode = decode_wrapper.forward(
            q_decode.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        o_decode = o_decode.view(-1, layer.tp_q_head_num * layer.head_dim)

        return torch.cat((o, o_decode), 0)
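For reference, merge_state above combines the ragged-kernel output over the new extend tokens (o1, s1) with the paged-kernel output over the cached prefix (o2, s2) using their log-sum-exp statistics. A minimal PyTorch sketch of that merge, conceptual only (the actual FlashInfer kernel is fused and its LSE scaling convention may differ):

import torch

def merge_lse(o1, s1, o2, s2):
    # o1, o2: [tokens, heads, head_dim] partial attention outputs
    # s1, s2: [tokens, heads] log-sum-exp of the attention scores behind each output
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max).unsqueeze(-1)
    w2 = torch.exp(s2 - s_max).unsqueeze(-1)
    return (o1 * w1 + o2 * w2) / (w1 + w2)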




    def forward_extend(
        self,
        q,
@@ -229,6 +377,8 @@ def forward_extend(
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if forward_batch.forward_mode.is_mixed():
            return self.forward_mixed(q, k, v, layer, forward_batch, save_kv_cache)
        prefill_wrapper_paged = self.prefill_wrappers_paged[
            self._get_wrapper_idx(layer)
        ]
@@ -346,6 +496,9 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.decode_wrappers = attn_backend.decode_wrappers

        # used for mixed mode
        self.decode_indices = 0

        # Dispatch
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
@@ -458,8 +611,17 @@ def call_begin_forward(
        kv_start_idx: torch.Tensor,
    ):
        bs = len(req_pool_indices)
-       kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-       kv_indptr = kv_indptr[: bs + 1]
+       if self.decode_indices > 0:
+           kv_indptr[1 + self.decode_indices] = 0
+           kv_indptr[2 + self.decode_indices : 2 + self.decode_indices + bs] = torch.cumsum(
+               paged_kernel_lens, dim=0
+           )
+           kv_indptr_decode = kv_indptr[1 + self.decode_indices : 2 + self.decode_indices + bs]
+           self.decode_indices = 0
+       else:
+           kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+           kv_indptr_decode = kv_indptr[: bs + 1]

        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )
@@ -468,15 +630,15 @@
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
-           kv_indptr,
+           kv_indptr_decode,
            kv_start_idx,
            kv_indices,
            self.req_to_token.shape[1],
        )

        wrapper.end_forward()
        wrapper.begin_forward(
-           kv_indptr,
+           kv_indptr_decode,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
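The decode_indices branch above appears to exist because, in mixed mode, the decode indices updater writes its indptr into the same preallocated kv_indptr buffer that the prefill updater has already filled for the extend slice; this is an inference from the offset arithmetic, not stated in the commit. A standalone sketch of that assumed layout with made-up lengths:

import torch

kv_indptr = torch.zeros(16, dtype=torch.int32)  # shared, preallocated buffer

# Prefill updater: 2 extend requests with paged KV lengths [512, 256].
kv_indptr[1:3] = torch.cumsum(torch.tensor([512, 256], dtype=torch.int32), dim=0)
kv_indptr_prefill = kv_indptr[:3]

# Decode updater with decode_indices = 2 (the extend batch size) and 3 decode
# requests of KV lengths [33, 17, 9]: written past the prefill slice so it does
# not overwrite it.
decode_indices, bs = 2, 3
kv_indptr[1 + decode_indices] = 0
kv_indptr[2 + decode_indices : 2 + decode_indices + bs] = torch.cumsum(
    torch.tensor([33, 17, 9], dtype=torch.int32), dim=0
)
kv_indptr_decode = kv_indptr[1 + decode_indices : 2 + decode_indices + bs]

print(kv_indptr_prefill.tolist())  # [0, 512, 768]
print(kv_indptr_decode.tolist())   # [0, 33, 50, 59]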
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -149,6 +149,9 @@ class ForwardBatch:
    gathered_buffer: Optional[torch.Tensor] = None
    can_run_dp_cuda_graph: bool = False

    # For mixed chunked prefill
    running_bs: int = 0

    def compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
