[Dynamic Spec Decoding] Minor fix for disabling speculative decoding #5000
Conversation
Thanks for the fix!
@@ -276,7 +276,8 @@ def execute_model(
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
Refine the comment to mention the auto-disable case?
LGTM, let's update the comment before merging
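One possible refinement of that comment (assumed wording, not necessarily what was merged):

```python
# If no spec tokens are generated -- either because this is a prefill step
# or because speculative decoding was auto-disabled for a large batch
# (running_queue_size >= speculative_disable_by_batch_size) -- call the
# proposer and scorer workers normally, bypassing the speculative flow.
```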
"ngram_prompt_lookup_max": 3, | ||
"speculative_disable_by_batch_size": 4 | ||
}]) | ||
@pytest.mark.parametrize("batch_size", [1, 2, 5, 8]) |
nit: suggest only [1, 5] to reduce test time
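For context, the config above enables ngram speculative decoding and auto-disables it once four or more requests are running. A minimal sketch of exercising the same knobs directly (the model name and exact keyword signatures are assumptions and may differ across vLLM versions):

```python
from vllm import LLM, SamplingParams

# Ngram-based speculative decoding that auto-disables itself when the
# running batch size reaches 4 (mirrors the test config above).
llm = LLM(
    model="JackFram/llama-160m",        # small test model (assumption)
    speculative_model="[ngram]",        # ngram proposer, no draft model
    num_speculative_tokens=5,           # assumed value; not shown in the diff
    ngram_prompt_lookup_max=3,
    speculative_disable_by_batch_size=4,
)

# With 8 concurrent prompts, the batch size exceeds the threshold, so this
# should take the non-speculative path after the fix.
outputs = llm.generate(["Hello, my name is"] * 8,
                       SamplingParams(max_tokens=32))
```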
Looks like the CI failure is unrelated and we should just merge this. cc @simon-mo
In the current implementation, even if `running_queue_size >= speculative_disable_by_batch_size`, execution still goes through the speculative decoding logic, which includes `get_spec_proposals` with k=0, `score_proposals`, rejection sampling, and creating the sampler output. This flow introduces extra overhead (especially the rejection sampling), which makes disabling speculative decoding slower than genuinely running without speculative decoding.

To fix this, we can simply reuse `_run_no_spec` so the SD flow is not touched at all (see the sketch below). This PR also adds a test to check correctness.

Concretely, for a batch size of 8 with 128 output tokens, TP=4, on Llama3-70B, the batch latency is:

Here, "without SD" means not using the SD flag at all; "disable SD" means using the SD flag but setting `speculative_disable_by_batch_size` smaller than the batch size, so speculative decoding is disabled. After the fix we are still slower than the no-SD case; this is caused by broadcasting the control flow, which will be fixed in future PRs.
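For illustration, here is a minimal sketch of the dispatch this description implies (simplified and assumed, not the exact vLLM source; names such as `disable_by_batch_size` and `skip_proposer` are illustrative):

```python
# Sketch of a SpecDecodeWorker.execute_model dispatch (assumed, simplified).
def execute_model(self, execute_model_req):
    # Speculation is auto-disabled once the running batch is large enough.
    disable_all_speculation = (
        execute_model_req.running_queue_size >= self.disable_by_batch_size
    )

    # If there are no spec tokens to score (prefill) or speculation is
    # auto-disabled, take the non-speculative path: no k=0 proposals, no
    # proposal scoring, and no rejection sampling.
    if execute_model_req.num_lookahead_slots == 0 or disable_all_speculation:
        return self._run_no_spec(
            execute_model_req,
            skip_proposer=disable_all_speculation,  # illustrative kwarg
        )

    # Otherwise, run the full speculative decoding step:
    # propose -> score -> rejection-sample -> build sampler output.
    return self._run_speculative_decoding_step(execute_model_req)
```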