MLA prefill w/o weight absorption #2349

Merged (4 commits) on Dec 4, 2024
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/attention/__init__.py
@@ -52,12 +52,13 @@ def forward(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch)
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
else:
return self.forward_extend(q, k, v, layer, forward_batch)
return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)

def forward_decode(
self,
@@ -66,6 +67,7 @@ def forward_decode(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for decode."""
raise NotImplementedError()
@@ -77,6 +79,7 @@ def forward_extend(
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
"""Run a forward for extend."""
raise NotImplementedError()
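
The new `save_kv_cache` argument defaults to `True`, so existing backends keep their behavior; a caller that has already written its own entry into the KV pool (as the DeepSeek-V2 MLA prefill path below does with its latent cache) passes `False` to skip the duplicate write. A minimal, self-contained sketch of that contract; the classes and shapes here are toy stand-ins, not the real sglang types:

```python
import torch
import torch.nn.functional as F


class ToyKVPool:
    """Toy stand-in for token_to_kv_pool: a flat slot -> (k, v) store."""

    def __init__(self, num_slots: int, num_heads: int, head_dim: int):
        self.k = torch.zeros(num_slots, num_heads, head_dim)
        self.v = torch.zeros(num_slots, num_heads, head_dim)

    def set_kv_buffer(self, cache_loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        self.k[cache_loc] = k
        self.v[cache_loc] = v


class ToyBackend:
    def forward_extend(self, q, k, v, pool, cache_loc, save_kv_cache: bool = True):
        # Write K/V into the pool only when the caller has not done so already.
        if save_kv_cache:
            pool.set_kv_buffer(cache_loc, k, v)
        # (tokens, heads, dim) -> (heads, tokens, dim) for SDPA, then back.
        return F.scaled_dot_product_attention(
            q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True
        ).transpose(0, 1)


# Caller-managed cache: store first, then ask the backend not to store again.
pool = ToyKVPool(num_slots=16, num_heads=2, head_dim=8)
q, k, v = torch.randn(4, 2, 8), torch.randn(4, 2, 8), torch.randn(4, 2, 8)
loc = torch.arange(4)
pool.set_kv_buffer(loc, k, v)
out = ToyBackend().forward_extend(q, k, v, pool, loc, save_kv_cache=False)
```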
30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/double_sparsity_backend.py
@@ -165,7 +165,13 @@ def get_cuda_graph_seq_len_fill_value(self):
return 1

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ def forward_extend(
.expand(k.shape[0], -1, -1),
)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)

(
start_loc,
@@ -212,7 +219,13 @@ def forward_extend(
return o

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ def forward_decode(
.expand(k.shape[0], -1, -1),
)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v, k_label
)

# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
# and set a minimum value for sparse_decode
25 changes: 20 additions & 5 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -221,7 +221,13 @@ def get_cuda_graph_seq_len_fill_value(self):
return 0

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
prefill_wrapper_paged = self.prefill_wrappers_paged[
self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ def forward_extend(
if not use_ragged:
if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ def forward_extend(

o, _ = merge_state(o1, s1, o2, s2)

forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
cache_loc = (
@@ -286,7 +300,8 @@ def forward_decode(

if k is not None:
assert v is not None
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
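
In the ragged extend path above, attention is computed in two passes: `o1, s1` from the new tokens attending to themselves (ragged prefill) and `o2, s2` from the same queries attending to the cached prefix (paged), after which `merge_state` combines the two using their log-sum-exp values. A reference torch version of that merge, for readability only; flashinfer's kernel and its log-base convention remain the source of truth:

```python
import torch


def merge_attention_states(o1, s1, o2, s2):
    # o1, o2: (tokens, heads, head_dim) partial attention outputs
    # s1, s2: (tokens, heads) log-sum-exp of the score set behind each output
    w1 = torch.sigmoid(s1 - s2).unsqueeze(-1)  # = exp(s1) / (exp(s1) + exp(s2))
    merged = w1 * o1 + (1.0 - w1) * o2
    merged_lse = torch.logaddexp(s1, s2)
    return merged, merged_lse


o1, o2 = torch.randn(4, 2, 64), torch.randn(4, 2, 64)
s1, s2 = torch.randn(4, 2), torch.randn(4, 2)
o, s = merge_attention_states(o1, s1, o2, s2)
```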
30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/torch_native_backend.py
@@ -216,16 +216,23 @@ def _run_sdpa_forward_decode(
return output

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -249,7 +256,13 @@ def forward_decode(
return o

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ def forward_decode(
else:
o = torch.empty_like(q)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

30 changes: 22 additions & 8 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -114,17 +114,24 @@ def get_cuda_graph_seq_len_fill_value(self):
return 1

def forward_extend(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
self.extend_attention_fwd(
@@ -146,7 +153,13 @@ def forward_decode(
return o

def forward_decode(
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ def forward_decode(

start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -284,6 +284,9 @@ def extend_attention_fwd(
elif Lq == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
elif Lq == 192:
BLOCK_DMODEL = 128
BLOCK_DPE = 64
else:
BLOCK_DMODEL = triton.next_power_of_2(Lq)
BLOCK_DPE = 0
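
The new `Lq == 192` case covers the un-absorbed MHA prefill added in this PR: DeepSeek-V2's per-head query/key width is `qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192`, so the Triton kernel tiles the non-positional part and the rotary part separately, keeping both tile widths powers of two. A pure-Python restatement of just the cases visible in this hunk, illustrative only:

```python
def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def select_blocks(lq: int) -> tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE) as in the branch above."""
    if lq == 288:
        return 256, 32
    elif lq == 192:          # DeepSeek-V2 MHA prefill: 128 "nope" + 64 rope channels
        return 128, 64
    else:
        return next_power_of_2(lq), 0   # no separate positional-encoding block


assert select_blocks(192) == (128, 64)
```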
6 changes: 4 additions & 2 deletions python/sglang/srt/layers/radix_attention.py
@@ -48,11 +48,13 @@ def __init__(
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention

def forward(self, q, k, v, forward_batch: ForwardBatch):
def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
return forward_batch.attn_backend.forward(
q, k, v, self, forward_batch, save_kv_cache
)
71 changes: 68 additions & 3 deletions python/sglang/srt/models/deepseek_v2.py
@@ -453,7 +453,7 @@ def __init__(
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

self.attn = RadixAttention(
self.attn_mqa = RadixAttention(
self.num_local_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
self.scaling,
@@ -462,6 +462,15 @@ def __init__(
v_head_dim=self.kv_lora_rank,
)

self.attn_mha = RadixAttention(
self.num_local_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
v_head_dim=self.v_head_dim,
)

self.w_kc = None
self.w_vc = None
self.w_scale = None
@@ -471,6 +480,63 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Use normal computation for prefill and use weight absorption for extend/decode
if (
forward_batch.forward_mode.is_extend()
and forward_batch.extend_prefix_lens.sum() == 0
):
return self.forward_normal(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)

def forward_normal(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = kv[..., : self.qk_nope_head_dim]
v = kv[..., self.qk_nope_head_dim :]
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe

latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe

# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output

def forward_absorb(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -508,7 +574,7 @@ def forward(
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe

attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -835,7 +901,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale"):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
del self_attn.kv_b_proj


EntryClass = DeepseekV2ForCausalLM
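
Taken together, the dispatcher above runs `forward_normal` (plain multi-head attention over the full 192-wide Q/K heads via `attn_mha`) only for extend batches with no cached prefix, and keeps the weight-absorbed `forward_absorb` path (`attn_mqa` over the compressed latent) for decode and extend-with-prefix. `forward_normal` still writes the compressed latent cache itself with `set_kv_buffer` and therefore calls `attn_mha(..., save_kv_cache=False)`, so later absorbed steps reuse the same cache entries. A standalone shape walk-through of that prefill path with DeepSeek-V2-style dimensions; the weights are random stand-ins and rotary embedding is omitted, so this checks only the tensor plumbing, not numerics:

```python
import torch
import torch.nn.functional as F

tokens, heads = 4, 2                              # toy sizes
kv_lora_rank, qk_nope, qk_rope, v_dim = 512, 128, 64, 128
qk_head_dim = qk_nope + qk_rope                   # 192 -> the new Lq case in the Triton kernel

q = torch.randn(tokens, heads, qk_head_dim)       # after the q projection(s)
latent = torch.randn(tokens, 1, kv_lora_rank + qk_rope)  # kv_a_proj_with_mqa output

kv_a = latent[:, 0, :kv_lora_rank]                # compressed KV, what actually gets cached
k_pe = latent[:, :, kv_lora_rank:]                # rotary part of K, shared across heads

w_kv_b = torch.randn(kv_lora_rank, heads * (qk_nope + v_dim))  # kv_b_proj weight stand-in
kv = (kv_a @ w_kv_b).view(tokens, heads, qk_nope + v_dim)
k_nope, v = kv[..., :qk_nope], kv[..., qk_nope:]

k = torch.empty(tokens, heads, qk_head_dim)
k[..., :qk_nope] = k_nope
k[..., qk_nope:] = k_pe                           # broadcast over heads

out = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True
).transpose(0, 1)
assert out.shape == (tokens, heads, v_dim)

# The real layer writes `latent` (not k/v) into the pool via set_kv_buffer on
# attn_mha and then calls attn_mha(q, k, v, forward_batch, save_kv_cache=False)
# so the backend does not overwrite that latent entry.
```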