From 2f9bd0fafd7bfe9f8c085a5f482635c8638accc6 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Sat, 14 Dec 2024 16:50:54 +0800
Subject: [PATCH] Fix correctness issue for triton decoding kernel (#2479)

---
 .../attention/triton_ops/decode_attention.py | 38 ++++++++++++-------
 test/srt/test_triton_attention_kernels.py    | 10 +++--
 2 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index d2e856ca605..469ab5ed242 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -32,7 +32,7 @@
 logger = logging.getLogger(__name__)
 
 # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warn(
+logger.warning(
     "The following error message 'operation scheduled before its operands' can be ignored."
 )
 
@@ -474,6 +474,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
+    B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -486,6 +487,8 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
 
@@ -497,19 +500,24 @@ def _fwd_kernel_stage2(
     offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
 
     for split_kv_id in range(0, NUM_KV_SPLITS):
-        tv = tl.load(
-            Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
-        )
-        tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
-        n_e_max = tl.maximum(tlogic, e_max)
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
-        old_scale = tl.exp(e_max - n_e_max)
-        acc *= old_scale
-        exp_logic = tl.exp(tlogic - n_e_max)
-        acc += exp_logic * tv
+        if split_kv_end > split_kv_start:
+            tv = tl.load(
+                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
+            )
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)
 
-        e_sum = e_sum * old_scale + exp_logic
-        e_max = n_e_max
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max
 
     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
@@ -523,6 +531,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
+    b_seq_len,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -541,6 +550,7 @@
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -580,7 +590,7 @@ def decode_attention_fwd_normal(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -608,7 +618,7 @@
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
 
 
 def decode_attention_fwd(
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 048d27c5658..2398af9b0a7 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -232,9 +232,9 @@ def test_decode_attention(self):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)
 
-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
 
     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ def test_grouped_decode_attention(self):
             (2, 128, 1, 576, 512),
         ]
 
-        for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+        for S in seq_lens:
+            for B, H_Q, H_KV, D, D_V in configs:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
 
 
 if __name__ == "__main__":
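
Note (illustrative, not part of the patch): the stage-2 kernel combines per-split partial attention
outputs with an online log-sum-exp. When b_seq_len is small relative to num_kv_splits, some splits
cover no KV tokens and may never be written by the stage-1 kernels, so the old loop could fold stale
or uninitialized values into the output; the patch recomputes each split's token range from b_seq_len
and skips empty splits. The sketch below is a minimal PyTorch reference of that corrected reduction.
It assumes the per-split partial outputs and their log-sum-exp values are passed as two separate
tensors (mid_o, lse) rather than packed into the kernel's single Mid_O buffer, and the helper name
reduce_splits_reference is hypothetical, not from the repository.

    import math

    import torch


    def reduce_splits_reference(mid_o, lse, b_seq_len, num_kv_splits):
        # mid_o: [batch, head, num_kv_splits, Lv] per-split partial outputs (assumed layout)
        # lse:   [batch, head, num_kv_splits] per-split log-sum-exp (assumed layout)
        batch, head, _, Lv = mid_o.shape
        out = torch.empty(batch, head, Lv, dtype=mid_o.dtype, device=mid_o.device)
        for b in range(batch):
            seq_len = int(b_seq_len[b])
            kv_len_per_split = math.ceil(seq_len / num_kv_splits)  # mirrors tl.cdiv
            for h in range(head):
                e_max, e_sum = float("-inf"), 0.0
                acc = torch.zeros(Lv, dtype=torch.float32, device=mid_o.device)
                for s in range(num_kv_splits):
                    split_start = kv_len_per_split * s
                    split_end = min(split_start + kv_len_per_split, seq_len)
                    if split_end <= split_start:
                        # Empty split: stage 1 produced nothing here, so skip it
                        # instead of mixing unwritten data into the result.
                        continue
                    tv = mid_o[b, h, s].float()
                    tlogic = float(lse[b, h, s])
                    # Online log-sum-exp merge, as in the kernel loop above.
                    n_e_max = max(tlogic, e_max)
                    old_scale = math.exp(e_max - n_e_max)
                    exp_logic = math.exp(tlogic - n_e_max)
                    acc = acc * old_scale + exp_logic * tv
                    e_sum = e_sum * old_scale + exp_logic
                    e_max = n_e_max
                out[b, h] = (acc / e_sum).to(mid_o.dtype)
        return out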