Remove unused vars in the triton backend #2401

Merged (3 commits) on Dec 8, 2024

python/sglang/srt/layers/attention/triton_backend.py (21 changes: 4 additions & 17 deletions)

@@ -35,11 +35,6 @@ def __init__(self, model_runner: ModelRunner):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
-
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
@@ -53,9 +48,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
             attn_logits = torch.empty(
                 (
                     forward_batch.batch_size,
@@ -67,13 +59,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 device=self.device,
             )
 
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -96,9 +87,7 @@ def init_forward_metadata_capture_cuda_graph(
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
 
@@ -137,7 +126,7 @@ def forward_extend(
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -175,7 +164,7 @@ def forward_decode(
         else:
             o = torch.empty_like(q)
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -189,10 +178,8 @@ def forward_decode(
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
             self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
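
For readers skimming the diff, here is a minimal standalone sketch (not sglang code itself) of the slimmed-down forward_metadata convention above: init_forward_metadata now packs a two-element tuple, (attn_logits, None) in decode mode and (None, max_extend_len) in extend mode, and each forward path unpacks only the element it needs. The helper name pack_forward_metadata and the tensor shape below are illustrative assumptions.

from typing import Optional, Tuple

import torch


def pack_forward_metadata(
    is_decode: bool,
    attn_logits: Optional[torch.Tensor] = None,
    max_extend_len: Optional[int] = None,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    # Mirrors init_forward_metadata after this PR: only the field needed by
    # the current mode is populated; start_loc and max_seq_len are gone.
    if is_decode:
        return attn_logits, None
    return None, max_extend_len


# Decode path: forward_decode reads only attn_logits.
# The shape (batch, num_head, num_kv_splits, v_head_dim + 1) is assumed for illustration.
attn_logits, _ = pack_forward_metadata(True, attn_logits=torch.empty(2, 8, 8, 129))

# Extend path: forward_extend reads only max_extend_len.
_, max_extend_len = pack_forward_metadata(False, max_extend_len=512)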

python/sglang/srt/layers/attention/triton_ops/decode_attention.py (20 changes: 10 additions & 10 deletions)

@@ -19,13 +19,23 @@
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
+import logging
+
 import triton
 import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
 
+logger = logging.getLogger(__name__)
+
+# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
+logger.warn(
+    "The following error message 'operation scheduled before its operands' can be ignored."
+)
+
 
 @triton.jit
 def tanh(x):
@@ -166,7 +176,6 @@ def _decode_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -556,7 +564,6 @@ def decode_attention_fwd_normal(
     b_req_idx,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -569,7 +576,6 @@ def decode_attention_fwd_normal(
         req_to_token,
         b_req_idx,
         b_seq_len,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -586,7 +592,6 @@ def decode_attention_fwd_grouped(
     b_req_idx,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -599,7 +604,6 @@ def decode_attention_fwd_grouped(
         req_to_token,
         b_req_idx,
         b_seq_len,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -614,10 +618,8 @@ def decode_attention_fwd(
     o,
     req_to_token,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap=0.0,
@@ -636,7 +638,6 @@ def decode_attention_fwd(
         b_req_idx,
         b_seq_len,
         attn_logits,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
@@ -652,7 +653,6 @@ def decode_attention_fwd(
         b_req_idx,
         b_seq_len,
         attn_logits,
-        max_len_in_batch,
         num_kv_splits,
         sm_scale,
         logit_cap,
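
As a usage sketch of the updated entry point, the call below exercises decode_attention_fwd with the trimmed argument list (no b_start_loc, no max_len_in_batch). It is modeled on the unit test in the next file; the leading q/k_buffer/v_buffer parameters, the tensor dtypes, and the attn_logits layout (one slot per KV split plus one extra channel) are assumptions based on the surrounding code, and running it requires a CUDA device with Triton installed.

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
)

B, H_Q, H_KV, D = 4, 16, 16, 128          # illustrative sizes only
seq_len, num_kv_splits = 64, 8
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)

q = torch.randn(B, H_Q, D, dtype=torch.float16, device="cuda")
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.float16, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.float16, device="cuda")
o = torch.empty(B, H_Q, D, dtype=torch.float16, device="cuda")

req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")

# Assumed layout: per-split partial results plus one extra slot, accumulated in fp32.
attn_logits = torch.empty(
    (B, H_Q, num_kv_splits, D + 1), dtype=torch.float32, device="cuda"
)

# Note the removed arguments: no b_start_loc, no max_len_in_batch.
decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
)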

test/srt/test_triton_attention_kernels.py (6 changes: 0 additions & 6 deletions)

@@ -196,7 +196,6 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -212,10 +211,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -255,7 +252,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -273,7 +269,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
             b_req_idx,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -293,7 +288,6 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
             b_req_idx,
             b_seq_len,
             attn_logits1,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )