Fix correctness issue for triton decoding kernel (#2479)
ispobock authored Dec 14, 2024
1 parent 5282a47 commit 2f9bd0f
Showing 2 changed files with 30 additions and 18 deletions.
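
In brief, the stage-2 reduce kernel _fwd_kernel_stage2 now receives the per-batch sequence length (B_Seqlen) and only accumulates KV splits that actually cover tokens, and the grouped decode test sweeps several sequence lengths instead of the fixed 128 it used before.
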
python/sglang/srt/layers/attention/triton_ops/decode_attention.py (24 additions, 14 deletions)
@@ -32,7 +32,7 @@
 logger = logging.getLogger(__name__)

 # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warn(
+logger.warning(
     "The following error message 'operation scheduled before its operands' can be ignored."
 )

@@ -474,6 +474,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
+    B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -486,6 +487,8 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv

@@ -497,19 +500,24 @@
     offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

     for split_kv_id in range(0, NUM_KV_SPLITS):
-        tv = tl.load(
-            Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
-        )
-        tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
-        n_e_max = tl.maximum(tlogic, e_max)
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

-        old_scale = tl.exp(e_max - n_e_max)
-        acc *= old_scale
-        exp_logic = tl.exp(tlogic - n_e_max)
-        acc += exp_logic * tv
+        if split_kv_end > split_kv_start:
+            tv = tl.load(
+                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
+            )
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)

-        e_sum = e_sum * old_scale + exp_logic
-        e_max = n_e_max
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max

     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
@@ -523,6 +531,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
+    b_seq_len,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -541,6 +550,7 @@
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -580,7 +590,7 @@ def decode_attention_fwd_normal(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)


def decode_attention_fwd_grouped(
Expand Down Expand Up @@ -608,7 +618,7 @@ def decode_attention_fwd_grouped(
sm_scale,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)


def decode_attention_fwd(
Expand Down
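
The heart of the change above is the new "if split_kv_end > split_kv_start" guard: ceil-division of the sequence length over NUM_KV_SPLITS can leave trailing splits past the end of the sequence (e.g. a 5-token sequence with 8 splits), so those splits cover zero tokens, and before this fix their Mid_O slots (presumably never written by stage 1) were still folded into the running log-sum-exp. Below is a minimal plain-PyTorch sketch of the same stage-2 reduction with that guard; it is an illustration only, and the function name and the [num_kv_splits, Lv] / [num_kv_splits] buffer layout are assumptions rather than the kernel's actual interface.

import math

import torch


def combine_kv_splits(mid_o, mid_lse, seq_len, num_kv_splits):
    # mid_o:   [num_kv_splits, Lv] partial weighted value sums, one row per KV split
    # mid_lse: [num_kv_splits]     log-sum-exp of the attention logits per split
    Lv = mid_o.shape[1]
    acc = torch.zeros(Lv, dtype=torch.float32)
    e_sum = 0.0
    e_max = float("-inf")

    kv_len_per_split = math.ceil(seq_len / num_kv_splits)  # tl.cdiv equivalent
    for split_id in range(num_kv_splits):
        split_start = kv_len_per_split * split_id
        split_end = min(split_start + kv_len_per_split, seq_len)
        if split_end <= split_start:
            continue  # the guard added in this commit: skip splits with no tokens

        tlogic = float(mid_lse[split_id])
        n_e_max = max(tlogic, e_max)           # running max for numerical stability
        old_scale = math.exp(e_max - n_e_max)  # rescale what was accumulated so far
        exp_logic = math.exp(tlogic - n_e_max)
        acc = acc * old_scale + exp_logic * mid_o[split_id].float()
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max

    return acc / e_sum  # final softmax normalization before writing O

When every split is non-empty the guard always passes and the result is identical to the old code, so only configurations that produce empty splits were affected.
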
test/srt/test_triton_attention_kernels.py (6 additions, 4 deletions)
@@ -232,9 +232,9 @@ def test_decode_attention(self):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)

-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ def test_grouped_decode_attention(self):
             (2, 128, 1, 576, 512),
         ]

-        for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+        for S in seq_lens:
+            for B, H_Q, H_KV, D, D_V in configs:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)


 if __name__ == "__main__":
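
The new seq_lens values exercise sequences that don't fill all KV splits. With the test's num_kv_splits = 8, seq_len = 5 is the case that yields empty trailing splits, while 100, 128, and 500 keep every split non-empty. A small stand-alone sketch (a hypothetical helper, not part of the test suite) that reproduces the per-split ranges now computed in _fwd_kernel_stage2:

import math


def split_ranges(seq_len: int, num_kv_splits: int):
    # Mirror of the kv_len_per_split / split_kv_start / split_kv_end arithmetic
    # added to _fwd_kernel_stage2.
    kv_len_per_split = math.ceil(seq_len / num_kv_splits)
    return [
        (kv_len_per_split * i, min(kv_len_per_split * (i + 1), seq_len))
        for i in range(num_kv_splits)
    ]


print(split_ranges(5, 8))
# [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 5), (6, 5), (7, 5)]
# The last three splits are empty (end <= start); the fixed kernel skips them,
# whereas the old kernel folded whatever sat in their Mid_O slots into the output.

print(split_ranges(128, 8))
# [(0, 16), (16, 32), ..., (112, 128)] -- every split non-empty, which is why the
# previous fixed seq_len of 128 never exposed the bug.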
