diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py
index 474d820cab00e..f67ddbc9d6326 100644
--- a/tests/spec_decode/test_dynamic_spec_decode.py
+++ b/tests/spec_decode/test_dynamic_spec_decode.py
@@ -51,11 +51,14 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
     assert seq_group_metadata_list[
         0].num_speculative_tokens == expected_num_spec_tokens
 
+    draft_worker.sampler_output.side_effect = ValueError(exception_secret)
+
     proposer = Top1Proposer(
         worker=draft_worker,
         device='cpu',  # not used
         vocab_size=100,  # not used
-        max_proposal_len=10,
+        # Must be long enough to avoid being skipped due to length.
+        max_proposal_len=1024,
     )
 
     if queue_size < disable_at_queue_size:
@@ -66,9 +69,9 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
                 num_lookahead_slots=k), )
     else:
         # Should not execute the draft model because spec decode is disabled
-        # for all requests.
+        # for all requests. Accordingly, the proposal length should be 0.
         proposals = proposer.get_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 num_lookahead_slots=k), )
-        assert proposals.proposal_lens.tolist() == [k] * batch_size
+        assert proposals.proposal_lens.tolist() == [0] * batch_size
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 61097e9ab94cf..afe2afbedf9b6 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -79,11 +79,8 @@ def create_worker(
             proposer_worker,
             scorer_worker,
             disable_at_queue_size=disable_at_queue_size,
-            # TODO(cade) disable strict mode for speedup.
             rejection_sampler=RejectionSampler(
-                disable_bonus_tokens=disable_bonus_tokens,
-                strict_mode=not envs.VLLM_DISABLE_REJECT_SAMPLING_STRICT_MODE),
-        )
+                disable_bonus_tokens=disable_bonus_tokens, ))
 
     def __init__(
         self,
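
As context for the test change: a minimal standalone sketch (not part of the diff) of the `side_effect` pattern it relies on. The names `draft_worker` and `exception_secret` mirror the test above; the sentinel message here is a placeholder, not taken from the source.

```python
# Sketch of the mocking pattern: when speculative decoding is disabled, the
# draft worker's sampler must never run, so any call raises the sentinel error.
from unittest.mock import MagicMock

import pytest

exception_secret = "artificial stop"  # placeholder sentinel message
draft_worker = MagicMock()
draft_worker.sampler_output.side_effect = ValueError(exception_secret)

# An accidental invocation of the draft model surfaces as a loud failure:
with pytest.raises(ValueError, match=exception_secret):
    draft_worker.sampler_output()
```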