
[Dynamic Spec Decoding] Auto-disable by the running queue size #4592

Merged (11 commits) on May 8, 2024
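For orientation, a minimal usage sketch of the option this PR adds (not part of the diff): the `speculative_disable_by_batch_size` engine argument and the `--speculative-disable-by-batch-size` flag come from the changes below, while the model names and the threshold of 4 are illustrative placeholders.

```python
# Illustrative sketch only: argument names come from the diff below;
# model choices and the threshold value are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    # Disable speculation for new requests once the running queue reaches
    # the threshold (4 here); the value must be > 1 (see vllm/config.py).
    speculative_disable_by_batch_size=4,
    use_v2_block_manager=True,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```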
13 changes: 9 additions & 4 deletions tests/samplers/test_rejection_sampler.py
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
def test_correct_output_format(which_tokens_accepted: str,
disable_bonus_tokens: bool, seed: int,
device: str):
"""Verify the output has correct format given predetermined accepted matrix.
"""
@@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1),
dtype=torch.int64)

rejection_sampler = RejectionSampler()
rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens)
rejection_sampler.init_gpu_tensors(rank=0)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
@@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)

# Bonus tokens are currently disabled. Verify they're set to -1.
expected_bonus_token_ids = bonus_token_ids.clone()
# If bonus tokens are disabled, verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
if disable_bonus_tokens:
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1

if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
34 changes: 34 additions & 0 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 1,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
77 changes: 77 additions & 0 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [2, 4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3

draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)

exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)

# When the running queue size is at or above the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens

draft_worker.sampler_output.side_effect = ValueError(exception_secret)

proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)

if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert proposals.proposal_lens.tolist() == [0] * batch_size
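The behavior this test pins down can be summarized with the following simplified sketch. It is not the actual SpecDecodeWorker code; the helper name is made up, and only the threshold comparison and the None/0 convention are taken from the test above.

```python
# Simplified sketch of the policy exercised by the test above
# (not the actual SpecDecodeWorker implementation).
def maybe_disable_speculation(seq_group_metadata_list, running_queue_size,
                              disable_by_batch_size):
    if (disable_by_batch_size is None
            or running_queue_size < disable_by_batch_size):
        # Below the threshold: leave num_speculative_tokens untouched
        # (it defaults to None, i.e. spec decode has not been applied).
        return
    for seq_group_metadata in seq_group_metadata_list:
        # At or above the threshold: disable speculation for this request.
        seq_group_metadata.num_speculative_tokens = 0
```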
29 changes: 24 additions & 5 deletions vllm/config.py
@@ -692,6 +692,7 @@ def maybe_create_spec_config(
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
) -> Optional["SpeculativeConfig"]:
@@ -720,6 +721,9 @@ def maybe_create_spec_config(
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
Expand All @@ -730,7 +734,7 @@ def maybe_create_spec_config(
the necessary conditions are met, else None.
"""

if (speculative_model is None and num_speculative_tokens is None):
if speculative_model is None and num_speculative_tokens is None:
return None

if speculative_model is not None and num_speculative_tokens is None:
@@ -739,6 +743,12 @@ def maybe_create_spec_config(
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")

if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")

assert (speculative_model is not None
and num_speculative_tokens is not None)

@@ -807,6 +817,7 @@ def maybe_create_spec_config(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
)
@@ -876,8 +887,9 @@ def __init__(
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
ngram_prompt_lookup_max: int,
ngram_prompt_lookup_min: int,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
):
"""Create a SpeculativeConfig object.

@@ -886,12 +898,19 @@ def __init__(
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueued requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0

self._verify_args()

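A standalone illustration of the new validation rule (not the real call path in `SpeculativeConfig.maybe_create_spec_config`): thresholds below 2 are rejected, since a threshold of 1 would disable speculation whenever any request is running.

```python
# Standalone illustration of the check added above; the function name
# is made up and is not part of vLLM.
def check_disable_threshold(speculative_disable_by_batch_size):
    if (speculative_disable_by_batch_size is not None
            and speculative_disable_by_batch_size < 2):
        raise ValueError(
            "Expected the batch size threshold for disabling speculative "
            f"decoding to be > 1, but got {speculative_disable_by_batch_size=}")


check_disable_threshold(None)  # OK: feature not enabled
check_disable_threshold(4)     # OK
check_disable_threshold(1)     # raises ValueError
```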
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -83,6 +83,7 @@ class EngineArgs:
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None

@@ -467,6 +468,13 @@ def add_cli_args(
'draft model. Sequences over this length will skip '
'speculation.')

parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueued requests is larger than this value.')

parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
@@ -547,6 +555,8 @@ def create_engine_config(self, ) -> EngineConfig:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
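A hedged sketch of how the new flag can be exercised through EngineArgs' own CLI parser; `add_cli_args` and `from_cli_args` are existing vLLM helpers, and the argument values here are placeholders.

```python
# Sketch: parse the new flag with EngineArgs' CLI parser.
# Flag names come from the diff above; the values are placeholders.
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    "--model", "JackFram/llama-160m",
    "--speculative-model", "JackFram/llama-68m",
    "--num-speculative-tokens", "5",
    "--speculative-disable-by-batch-size", "4",
    "--use-v2-block-manager",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.speculative_disable_by_batch_size)  # -> 4
```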
2 changes: 2 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -93,6 +93,8 @@ def _init_spec_worker(self):
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)

assert self.parallel_config.world_size == 1, (
11 changes: 9 additions & 2 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf.
"""

def __init__(self, strict_mode: bool = False):
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
"""Create a rejection sampler.

Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Required when bonus tokens would corrupt the KV cache of
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode

# NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -312,7 +318,8 @@ def _create_output(
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1

# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
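To make the `disable_bonus_tokens` behavior concrete, a toy illustration (not part of the PR; tensor values are made up): when bonus tokens are disabled, the last column of the output is overwritten with -1, which is exactly what the updated sampler test checks.

```python
import torch

# Toy illustration of the masking above: the bonus token occupies the
# last column and is replaced with -1 when disable_bonus_tokens is set.
disable_bonus_tokens = True
output_with_bonus_tokens = torch.tensor([[101, 102, 103, 999],
                                         [201, 202, 203, 998]])
if disable_bonus_tokens:
    output_with_bonus_tokens[:, -1] = -1
print(output_with_bonus_tokens)
# tensor([[101, 102, 103,  -1],
#         [201, 202, 203,  -1]])
```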
6 changes: 6 additions & 0 deletions vllm/sequence.py
@@ -612,6 +612,12 @@ def __init__(
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample

# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reason.
# TODO: We should maintain this state outside of the sequence group.
self.num_speculative_tokens = None

if self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len()
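For clarity, an illustrative helper (not in the PR; the function name is made up) spelling out the three states described in the comment above.

```python
# Illustrative only: the helper name is made up; the three states match
# the comment above on num_speculative_tokens.
def describe_spec_state(num_speculative_tokens):
    if num_speculative_tokens is None:
        return "speculative decoding is not used for this request"
    if num_speculative_tokens == 0:
        return "speculative decoding is disabled (e.g. queue above threshold)"
    return f"proposing {num_speculative_tokens} speculative tokens per step"
```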