Optimize Triton decoding kernel for long context #2394
Conversation
@@ -705,10 +650,10 @@ def decode_attention_fwd(
    o,
    req_to_token,
    b_req_idx,
    b_start_loc,
Remove this from the func signature of decode_attention_fwd?
Sure. max_len_in_batch and triton_attention_reduce_in_fp32 may also need to be removed.
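For reference, a hedged sketch of how the trimmed wrapper could look once b_start_loc (and possibly max_len_in_batch and triton_attention_reduce_in_fp32) are dropped; any parameter name not visible in the diff hunks on this page is an assumption, not confirmed code:

```python
# Hypothetical post-cleanup signature: only the removed parameters are
# confirmed by this thread; the remaining names follow the visible diffs.
def decode_attention_fwd(
    q,              # query tensor for the decode step
    k_buffer,       # paged KV cache, keys
    v_buffer,       # paged KV cache, values
    o,              # output tensor
    req_to_token,   # request -> KV-cache token-index mapping
    b_req_idx,      # per-batch request indices
    b_seq_len,      # per-request sequence lengths
    attn_logits,    # intermediate buffer (partial outputs + log-sum-exp)
    num_kv_splits,  # number of KV chunks per sequence
    sm_scale,       # softmax scaling factor
    logit_cap=0.0,
):
    ...
```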
    forward_batch.batch_size,
    self.num_head,
    self.num_kv_splits,
    self.v_head_dim + 1,
After this, do we no longer need to reduce the CUDA graph max batch size for DeepSeek models?
Let me verify it.
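For context on the question above, a minimal sketch of what an allocation with the diffed shape looks like (example dimensions are assumed, not values from the PR). The trailing `+ 1` plausibly reserves one slot per (request, head, KV-split) for that split's log-sum-exp next to its partial output, and since the buffer scales with batch size times num_kv_splits, its footprint is what the CUDA-graph max-batch question is about:

```python
import torch

# Assumed example dimensions; not values from the PR.
batch_size, num_heads, num_kv_splits, v_head_dim = 32, 16, 8, 128

# One float32 partial output of length v_head_dim per
# (request, head, kv-split), plus one extra slot (the "+ 1") that
# plausibly holds that split's log-sum-exp so the final reduction
# kernel can combine the partials in a numerically stable way.
attn_logits = torch.empty(
    (batch_size, num_heads, num_kv_splits, v_head_dim + 1),
    dtype=torch.float32,
    device="cuda",
)
```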
Motivation
As mentioned in #2271, the original Triton decoding kernel suffers significant performance degradation on long context. We refactored the kernel and adapted the flash-decoding implementation from lightllm; the long-context slowdown is now largely alleviated.
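For readers unfamiliar with flash decoding, here is a minimal NumPy sketch of the split-KV idea (illustrative only, not the kernel's actual code): the KV sequence is chunked, each chunk produces a partial attention output plus the log-sum-exp of its scores, and a final reduction merges the partials:

```python
import numpy as np

def attend(q, k, v):
    # Attention over one KV chunk: return the chunk's partial output
    # and the log-sum-exp of its scaled scores for later merging.
    s = (k @ q) / np.sqrt(q.shape[0])   # scores, shape [chunk_len]
    m = s.max()
    p = np.exp(s - m)
    return (p @ v) / p.sum(), m + np.log(p.sum())

def flash_decode(q, k, v, num_kv_splits=4):
    # Split the KV sequence, attend to each chunk independently, then
    # merge the partial outputs weighted by each chunk's softmax mass.
    chunks = np.array_split(np.arange(len(k)), num_kv_splits)
    outs, lses = zip(*(attend(q, k[c], v[c]) for c in chunks))
    w = np.exp(np.array(lses) - max(lses))
    w /= w.sum()                        # per-chunk softmax weights
    return sum(wi * oi for wi, oi in zip(w, outs))

# Sanity check: split-KV decoding matches full attention.
rng = np.random.default_rng(0)
q = rng.normal(size=64)
k, v = rng.normal(size=(1024, 64)), rng.normal(size=(1024, 64))
assert np.allclose(flash_decode(q, k, v), attend(q, k, v)[0])
```

Because each chunk is independent, the splits run in parallel across SMs even at batch size 1, which is exactly where a single long sequence otherwise underutilizes the GPU.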
Benchmark
Tested with input 128, output 2048 (decode speed at the start -> end of generation):

- Triton (main branch): 147 -> 126
- Triton (this PR), num_kv_splits=8: 150 -> 138
- Triton (this PR), num_kv_splits=16: 150 -> 144
- Flashinfer: 143 -> 143

We can increase --triton-attention-num-kv-splits to get better performance on long context.
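For example, a launch command along these lines should raise the split count (a sketch: --triton-attention-num-kv-splits is named in this PR, but the surrounding flags and model path are assumptions about sglang's server launcher):

```bash
# Hypothetical invocation; the model path is a placeholder.
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V2-Lite \
  --attention-backend triton \
  --triton-attention-num-kv-splits 16
```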