From ec52464ddeabcc70b1fd3117b93adfefd5cb7ed0 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Thu, 5 Dec 2024 01:50:28 +0800 Subject: [PATCH] MLA prefill w/o weight absorption (#2349) --- .../sglang/srt/layers/attention/__init__.py | 7 +- .../attention/double_sparsity_backend.py | 30 +++++--- .../layers/attention/flashinfer_backend.py | 25 +++++-- .../layers/attention/torch_native_backend.py | 30 +++++--- .../srt/layers/attention/triton_backend.py | 30 +++++--- .../attention/triton_ops/extend_attention.py | 3 + python/sglang/srt/layers/radix_attention.py | 6 +- python/sglang/srt/models/deepseek_v2.py | 71 ++++++++++++++++++- 8 files changed, 166 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index f5d573f5f7b..a70e9537bfe 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -52,12 +52,13 @@ def forward( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run forward on an attention layer.""" if forward_batch.forward_mode.is_decode(): - return self.forward_decode(q, k, v, layer, forward_batch) + return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) else: - return self.forward_extend(q, k, v, layer, forward_batch) + return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache) def forward_decode( self, @@ -66,6 +67,7 @@ def forward_decode( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run a forward for decode.""" raise NotImplementedError() @@ -77,6 +79,7 @@ def forward_extend( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, + save_kv_cache: bool = True, ): """Run a forward for extend.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 73c32df8f6e..856aa984c38 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -165,7 +165,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 1 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -181,9 +187,10 @@ def forward_extend( .expand(k.shape[0], -1, -1), ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v, k_label - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, k_label + ) ( start_loc, @@ -212,7 +219,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -242,9 +255,10 @@ def forward_decode( .expand(k.shape[0], -1, -1), ) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v, k_label - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v, k_label + ) # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num # and set a minimum value for sparse_decode diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 258659efa2a..f89bc2ccaa2 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -221,7 +221,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 0 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): prefill_wrapper_paged = self.prefill_wrappers_paged[ self._get_wrapper_idx(layer) @@ -237,7 +243,8 @@ def forward_extend( if not use_ragged: if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -270,12 +277,19 @@ def forward_extend( o, _ = merge_state(o1, s1, o2, s2) - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] cache_loc = ( @@ -286,7 +300,8 @@ def forward_decode( if k is not None: assert v is not None - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py index 4ccad2216f7..5e7e0e66e22 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -216,16 +216,23 @@ def _run_sdpa_forward_decode( return output def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): if layer.qk_head_dim != layer.v_head_dim: o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num @@ -249,7 +256,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -260,9 +273,10 @@ def forward_decode( else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index b9597b3ea41..1b7c4c46d26 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -114,7 +114,13 @@ def get_cuda_graph_seq_len_fill_value(self): return 1 def forward_extend( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # TODO: reuse the buffer across layers if layer.qk_head_dim != layer.v_head_dim: @@ -122,9 +128,10 @@ def forward_extend( else: o = torch.empty_like(q) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata self.extend_attention_fwd( @@ -146,7 +153,13 @@ def forward_extend( return o def forward_decode( - self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch + self, + q, + k, + v, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, ): # During torch.compile, there is a bug in rotary_emb that causes the # output value to have a 3D tensor shape. This reshapes the output correctly. @@ -160,9 +173,10 @@ def forward_decode( start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index 56cc439c31e..b7afd62e723 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -284,6 +284,9 @@ def extend_attention_fwd( elif Lq == 288: BLOCK_DMODEL = 256 BLOCK_DPE = 32 + elif Lq == 192: + BLOCK_DMODEL = 128 + BLOCK_DPE = 64 else: BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 5d8c6470178..1df29ec68a9 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -48,11 +48,13 @@ def __init__( self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention - def forward(self, q, k, v, forward_batch: ForwardBatch): + def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True): if k is not None: # For cross-layer sharing, kv can be None assert v is not None k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - return forward_batch.attn_backend.forward(q, k, v, self, forward_batch) + return forward_batch.attn_backend.forward( + q, k, v, self, forward_batch, save_kv_cache + ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 424f86aec28..e83774ff55e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -453,7 +453,7 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = RadixAttention( + self.attn_mqa = RadixAttention( self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim, self.scaling, @@ -462,6 +462,15 @@ def __init__( v_head_dim=self.kv_lora_rank, ) + self.attn_mha = RadixAttention( + self.num_local_heads, + self.qk_nope_head_dim + self.qk_rope_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + layer_id=layer_id, + v_head_dim=self.v_head_dim, + ) + self.w_kc = None self.w_vc = None self.w_scale = None @@ -471,6 +480,63 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + ) -> torch.Tensor: + # Use normal computation for prefill and use weight absorption for extend/decode + if ( + forward_batch.forward_mode.is_extend() + and forward_batch.extend_prefix_lens.sum() == 0 + ): + return self.forward_normal(positions, hidden_states, forward_batch) + else: + return self.forward_absorb(positions, hidden_states, forward_batch) + + def forward_normal( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope = kv[..., : self.qk_nope_head_dim] + v = kv[..., self.qk_nope_head_dim :] + k_pe = latent_cache[:, :, self.kv_lora_rank :] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) + latent_cache[:, :, self.kv_lora_rank :] = k_pe + + # Save latent cache + forward_batch.token_to_kv_pool.set_kv_buffer( + self.attn_mha, forward_batch.out_cache_loc, latent_cache, None + ) + attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) + attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + def forward_absorb( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: q_len = hidden_states.shape[0] q_input = hidden_states.new_empty( @@ -508,7 +574,7 @@ def forward( q_input[..., self.kv_lora_rank :] = q_pe k_input[..., self.kv_lora_rank :] = k_pe - attn_output = self.attn(q_input, k_input, v_input, forward_batch) + attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.w_vc.dtype == torch.float8_e4m3fn: @@ -835,7 +901,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self_attn.w_vc = w_vc.contiguous().transpose(1, 2) if hasattr(self_attn.kv_b_proj, "weight_scale"): self_attn.w_scale = self_attn.kv_b_proj.weight_scale - del self_attn.kv_b_proj EntryClass = DeepseekV2ForCausalLM