diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1ea193ae7c3..1a539ebd75c 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -35,11 +35,6 @@ def __init__(self, model_runner: ModelRunner): model_runner.model_config.num_attention_heads // model_runner.tp_size ) - if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): - self.reduce_dtype = torch.float32 - else: - self.reduce_dtype = torch.float16 - self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] @@ -53,9 +48,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" if forward_batch.forward_mode.is_decode(): - start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) - start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - attn_logits = torch.empty( ( forward_batch.batch_size, @@ -67,13 +59,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): device=self.device, ) - max_seq_len = torch.max(forward_batch.seq_lens).item() max_extend_len = None else: - start_loc = attn_logits = max_seq_len = None + attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len + self.forward_metadata = attn_logits, max_extend_len def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len @@ -96,9 +87,7 @@ def init_forward_metadata_capture_cuda_graph( ): # NOTE: encoder_lens expected to be zeros or None self.forward_metadata = ( - self.cuda_graph_start_loc, self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, None, ) @@ -137,7 +126,7 @@ def forward_extend( layer, 
forward_batch.out_cache_loc, k, v ) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + _, max_extend_len = self.forward_metadata self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -175,7 +164,7 @@ def forward_decode( else: o = torch.empty_like(q) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + attn_logits, _ = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( @@ -189,10 +178,8 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), forward_batch.req_to_token_pool.req_to_token, forward_batch.req_pool_indices, - start_loc, forward_batch.seq_lens, attn_logits, - max_seq_len, self.num_kv_splits, layer.scaling, layer.logit_cap, diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 9eeb98a2963..d2e856ca605 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -19,6 +19,9 @@ # Adapted from # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +import logging + import triton import triton.language as tl @@ -26,6 +29,13 @@ is_hip_ = is_hip() +logger = logging.getLogger(__name__) + +# TODO: Remove this once triton>=3.2.0 is required. The Triton message below is harmless; it affects neither performance nor accuracy. +logger.warning( + "The following error message 'operation scheduled before its operands' can be ignored." 
+) + @triton.jit def tanh(x): @@ -166,7 +176,6 @@ def _decode_att_m_fwd( Req_to_tokens, B_req_idx, B_Seqlen, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd( Req_to_tokens, B_req_idx, B_Seqlen, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -556,7 +564,6 @@ def decode_attention_fwd_normal( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap=0.0, @@ -569,7 +576,6 @@ def decode_attention_fwd_normal( req_to_token, b_req_idx, b_seq_len, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -586,7 +592,6 @@ def decode_attention_fwd_grouped( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap=0.0, @@ -599,7 +604,6 @@ def decode_attention_fwd_grouped( req_to_token, b_req_idx, b_seq_len, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -614,10 +618,8 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap=0.0, @@ -636,7 +638,6 @@ def decode_attention_fwd( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, @@ -652,7 +653,6 @@ def decode_attention_fwd( b_req_idx, b_seq_len, attn_logits, - max_len_in_batch, num_kv_splits, sm_scale, logit_cap, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index b7917345b5b..048d27c5658 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -196,7 +196,6 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( @@ -212,10 +211,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): 
o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - seq_len, num_kv_splits, sm_scale, ) @@ -255,7 +252,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( @@ -273,7 +269,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): b_req_idx, b_seq_len, attn_logits, - seq_len, num_kv_splits, sm_scale, ) @@ -293,7 +288,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): b_req_idx, b_seq_len, attn_logits1, - seq_len, num_kv_splits, sm_scale, )