MLA prefill w/o weight absorption (#2349)
ispobock authored Dec 4, 2024
1 parent eb0c1f5 commit ec52464
Showing 8 changed files with 166 additions and 36 deletions.
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/attention/__init__.py
@@ -52,12 +52,13 @@ def forward(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch)
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
else:
return self.forward_extend(q, k, v, layer, forward_batch)
return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)

def forward_decode(
self,
@@ -66,6 +67,7 @@ def forward_decode(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for decode."""
raise NotImplementedError()
@@ -77,6 +79,7 @@ def forward_extend(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for extend."""
raise NotImplementedError()
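
The base-interface change above is mechanical, but it is what the rest of the commit hangs on: forward/forward_extend/forward_decode gain a save_kv_cache flag (default True), and each backend below wraps its set_kv_buffer call in "if save_kv_cache:". A minimal, self-contained sketch of the pattern (toy classes, not the real sglang ones), just to show the intent in isolation:

from typing import Dict, List, Tuple


class ToyKVPool:
    """Stand-in for token_to_kv_pool: maps a cache slot to its (k, v) pair."""

    def __init__(self) -> None:
        self.buf: Dict[int, Tuple[List[float], List[float]]] = {}

    def set_kv_buffer(self, loc: int, k: List[float], v: List[float]) -> None:
        self.buf[loc] = (k, v)


class ToyBackend:
    """Stand-in for an attention backend honoring the new flag."""

    def forward_extend(self, q, k, v, pool: ToyKVPool, loc: int, save_kv_cache: bool = True):
        if save_kv_cache:  # new in this commit: the KV write is now optional
            pool.set_kv_buffer(loc, k, v)
        return q  # placeholder for the attention output


pool = ToyKVPool()
backend = ToyBackend()
backend.forward_extend([1.0], [2.0], [3.0], pool, loc=0)                       # default: backend caches k/v
backend.forward_extend([1.0], [2.0], [3.0], pool, loc=1, save_kv_cache=False)  # caller manages the cache itself
assert 0 in pool.buf and 1 not in pool.buf

The caller that actually passes save_kv_cache=False is the DeepSeek-V2 MLA layer further down, which writes its compressed latent cache by hand before running attention.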
30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/double_sparsity_backend.py
@@ -165,7 +165,13 @@ def get_cuda_graph_seq_len_fill_value(self):
return 1

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ def forward_extend(
.expand(k.shape[0], -1, -1),
)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)

(
start_loc,
@@ -212,7 +219,13 @@ def forward_extend(
return o

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
Expand Down Expand Up @@ -242,9 +255,10 @@ def forward_decode(
.expand(k.shape[0], -1, -1),
)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)

# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
# and set a minimum value for sparse_decode
25 changes: 20 additions & 5 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -221,7 +221,13 @@ def get_cuda_graph_seq_len_fill_value(self):
return 0

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
prefill_wrapper_paged = self.prefill_wrappers_paged[
self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ def forward_extend(
if not use_ragged:
if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ def forward_extend(

o, _ = merge_state(o1, s1, o2, s2)

forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
cache_loc = (
@@ -286,7 +300,8 @@ def forward_decode(

if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/torch_native_backend.py
@@ -216,16 +216,23 @@ def _run_sdpa_forward_decode(
return output

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -249,7 +256,13 @@ def forward_extend(
return o

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ def forward_decode(
else:
o = torch.empty_like(q)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -114,17 +114,24 @@ def get_cuda_graph_seq_len_fill_value(self):
return 1

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
self.extend_attention_fwd(
@@ -146,7 +153,13 @@ def forward_decode(
return o

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ def forward_decode(

start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -284,6 +284,9 @@ def extend_attention_fwd(
elif Lq == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
elif Lq == 192:
BLOCK_DMODEL = 128
BLOCK_DPE = 64
else:
BLOCK_DMODEL = triton.next_power_of_2(Lq)
BLOCK_DPE = 0
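
The new Lq == 192 branch is the kernel-side counterpart of the MLA prefill path: assuming the usual DeepSeek-V2 head sizes (qk_nope_head_dim = 128 plus qk_rope_head_dim = 64, not values read from this diff), the query head dim seen by extend_attention_fwd is 192, which the kernel splits into a 128-wide no-positional block and a 64-wide RoPE block. A small sketch mirroring just the visible dispatch; next_power_of_2 is inlined so the snippet does not need triton:

from typing import Tuple


def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def choose_blocks(lq: int) -> Tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE) the way the hunk above does."""
    if lq == 288:
        return 256, 32
    elif lq == 192:  # new in this commit: MLA prefill head dim
        return 128, 64
    else:
        return next_power_of_2(lq), 0


assert choose_blocks(192) == (128, 64)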
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/radix_attention.py
@@ -48,11 +48,13 @@ def __init__(
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention

def forward(self, q, k, v, forward_batch: ForwardBatch):
def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
return forward_batch.attn_backend.forward(
q, k, v, self, forward_batch, save_kv_cache
)
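
RadixAttention.forward now simply threads save_kv_cache through to the backend after the same k/v reshape as before. One detail worth making explicit for the MLA paths: q/k and v do not share a head dim (qk_head_dim is the nope + rope width, v_head_dim is smaller), which is why the backends above allocate the output as (tokens, heads * v_head_dim) rather than torch.empty_like(q). A short torch sketch of that shape bookkeeping, with assumed DeepSeek-V2 sizes (192 and 128) and a made-up head count; the real code uses the tp_* head counts:

import torch

tokens, heads = 4, 16
qk_head_dim, v_head_dim = 192, 128  # assumed: qk_nope (128) + qk_rope (64) vs. v

q = torch.randn(tokens, heads * qk_head_dim)
k = torch.randn(tokens, heads * qk_head_dim)
v = torch.randn(tokens, heads * v_head_dim)

k = k.view(-1, heads, qk_head_dim)  # what RadixAttention.forward does to k
v = v.view(-1, heads, v_head_dim)   # ... and to v

# Backend-side output buffer when qk_head_dim != v_head_dim:
o = q.new_empty((q.shape[0], heads * v_head_dim))
print(q.shape, k.shape, v.shape, o.shape)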
71 changes: 68 additions & 3 deletions python/sglang/srt/models/deepseek_v2.py
@@ -453,7 +453,7 @@ def __init__(
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

self.attn = RadixAttention(
self.attn_mqa = RadixAttention(
self.num_local_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
self.scaling,
@@ -462,6 +462,15 @@ def __init__(
v_head_dim=self.kv_lora_rank,
)

self.attn_mha = RadixAttention(
self.num_local_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
v_head_dim=self.v_head_dim,
)

self.w_kc = None
self.w_vc = None
self.w_scale = None
@@ -471,6 +480,63 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Use normal computation for prefill and use weight absorption for extend/decode
if (
forward_batch.forward_mode.is_extend()
and forward_batch.extend_prefix_lens.sum() == 0
):
return self.forward_normal(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)

def forward_normal(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = kv[..., : self.qk_nope_head_dim]
v = kv[..., self.qk_nope_head_dim :]
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe

latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe

# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output

def forward_absorb(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -508,7 +574,7 @@ def forward(
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe

attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -835,7 +901,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
del self_attn.kv_b_proj


EntryClass = DeepseekV2ForCausalLM
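
Taken together, DeepseekV2AttentionMLA now keeps two RadixAttention modules and picks one per batch: a pure prefill batch (extend with no cached prefix) runs ordinary multi-head attention over the decompressed heads (attn_mha) and stores only the compressed latent cache, while extend-with-a-cached-prefix and decode keep the weight-absorbed MQA path over the latent (attn_mqa). A hedged, runnable sketch of the dispatch and of the head sizes each module sees, using assumed DeepSeek-V2 config values (kv_lora_rank = 512, qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128), not values read from this diff:

kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 512, 128, 64, 128

# attn_mqa: weight-absorbed path, attends over the compressed latent cache.
mqa_qk_dim = kv_lora_rank + qk_rope_head_dim      # 576
mqa_v_dim = kv_lora_rank                          # 512

# attn_mha: new "normal" prefill path, attends over the decompressed heads.
mha_qk_dim = qk_nope_head_dim + qk_rope_head_dim  # 192 (matches the new Lq == 192 triton branch)
mha_v_dim = v_head_dim                            # 128


def use_weight_absorption(is_extend: bool, extend_prefix_len_sum: int) -> bool:
    # Mirrors forward() above: pure prefill (extend, no cached prefix) -> normal MHA;
    # everything else (extend with a cached prefix, decode) -> weight absorption.
    return not (is_extend and extend_prefix_len_sum == 0)


assert not use_weight_absorption(is_extend=True, extend_prefix_len_sum=0)   # prefill -> attn_mha
assert use_weight_absorption(is_extend=True, extend_prefix_len_sum=128)     # cached prefix -> attn_mqa
assert use_weight_absorption(is_extend=False, extend_prefix_len_sum=0)      # decode -> attn_mqa

This is also why load_weights no longer deletes kv_b_proj: the normal prefill path needs it to decompress the latent into per-head k_nope and v.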
