Skip to content

Commit

Permalink
[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step s…
Browse files Browse the repository at this point in the history
…cheduling (vllm-project#9038)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
2 people authored and Alvant committed Oct 26, 2024
1 parent 9c90c1f commit 9db54b6
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 110 deletions.
81 changes: 81 additions & 0 deletions tests/core/test_num_computed_tokens_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest

from tests.conftest import VllmRunner
from tests.core.utils import create_dummy_prompt
from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform
from vllm.sequence import SequenceGroup

MODEL = "JackFram/llama-160m"


def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
scheduler = engine.scheduler[0]
scheduler.add_seq_group(seq_group)


@pytest.mark.parametrize("num_scheduler_steps", [1, 8])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_num_computed_tokens_update(num_scheduler_steps: int,
enable_chunked_prefill: bool,
enforce_eager: bool):

is_multi_step = num_scheduler_steps > 1
is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill

if is_multi_step_chunked_prefill and current_platform.is_rocm():
pytest.skip("Multi-step with Chunked-Prefill does not support "
"rocm_flash_attn backend")

# Make a vllm engine
runner = VllmRunner(model_name=MODEL,
gpu_memory_utilization=0.7,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
enable_chunked_prefill=enable_chunked_prefill,
enforce_eager=enforce_eager)
engine: LLMEngine = runner.model.llm_engine

# In multi-step + chunked-prefill there is no separate single prompt step.
# What is scheduled will run for num_scheduler_steps always.
num_prompt_steps = num_scheduler_steps \
if is_multi_step_chunked_prefill else 1

num_output_tokens_list = [4, 8, 12, 15, 16, 17]

# Create sequence and add to engine
prompt_len = 10

for req_idx, num_output_tokens in enumerate(num_output_tokens_list):
seq, seq_group = create_dummy_prompt(request_id=str(req_idx),
prompt_length=prompt_len,
min_tokens=num_output_tokens,
max_tokens=num_output_tokens)
add_seq_group_to_engine(engine, seq_group)

assert seq.data.get_num_computed_tokens() == 0

for _ in range(num_prompt_steps):
# prompt steps
engine.step()

if not seq.is_finished():
prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
# Test correctness of num_computed_tokens after the prompt steps
assert prompt_num_computed_tokens == \
prompt_len + num_prompt_steps - 1

decode_step_counter = 0
while not seq.is_finished():
# Test correctness of num_computed_tokens after the decode steps
assert seq.data.get_num_computed_tokens(
) == prompt_num_computed_tokens + decode_step_counter
for _ in range(num_scheduler_steps):
# decode step
engine.step()
decode_step_counter += 1

# Test correctness of num_computed_tokens after the sequence finish.
assert seq.data.get_num_computed_tokens(
) == prompt_len + num_output_tokens - 1
6 changes: 5 additions & 1 deletion tests/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def create_dummy_prompt(
use_beam_search: bool = False,
best_of: int = 1,
prompt_tokens: Optional[List[int]] = None,
min_tokens: int = 0,
max_tokens: int = 16,
) -> Tuple[Sequence, SequenceGroup]:
if not block_size:
block_size = prompt_length
Expand All @@ -36,7 +38,9 @@ def create_dummy_prompt(
arrival_time=time.time(),
sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of),
best_of=best_of,
max_tokens=max_tokens,
min_tokens=min_tokens),
lora_request=lora_request)

return prompt, seq_group
Expand Down
14 changes: 12 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""

assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with rocm_flash_attn yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")

# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
Expand Down
156 changes: 66 additions & 90 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,45 @@ def _process_sequence_group_outputs(

return

def _update_num_computed_tokens_for_multi_step_prefill(
self, seq_group: SequenceGroup,
seq_group_meta: SequenceGroupMetadata,
is_first_step_output: Optional[bool]):
"""
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.
seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
in multi-step are submitted in a single burst.
"""

assert self.scheduler_config.is_multi_step

if not seq_group_meta.is_prompt:
# num_computed_token updates for multi-step decodes happen after
# the tokens are appended to the sequence.
return

do_update: bool = False
if self.scheduler_config.chunked_prefill_enabled:
# In multi-step + chunked-prefill case, the prompt sequences
# that are scheduled are fully processed in the first step.
do_update = is_first_step_output is None or is_first_step_output
else:
# Normal multi-step decoding case. In this case prompt-sequences
# are actually single-stepped. Always update in this case.
assert seq_group.state.num_steps == 1
do_update = True

if do_update:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)

def _process_model_outputs(self,
ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:
Expand All @@ -1019,64 +1058,6 @@ def _process_model_outputs(self,
request_id: If provided, then only this request is going to be processed
"""

def update_prefill_num_computed_tokens(
seq_group: SequenceGroup,
seq_group_meta: SequenceGroupMetadata, num_outputs: int,
is_first_step_output: Optional[bool]) -> None:
"""
When multi-step and chunked-prefill are enabled together, the
prefill sequence scheduled for multi-step execution turn into
decodes in the first step itself. This function accounts
for that conversion.
seq_group: SequenceGroup - A prefill seq_group
seq_group_meta: SequenceGroupMetadata - Metadata of the given
prefill seq_group
num_outputs: int - number of output tokens being processed for the
given seq_group
is_first_step_output: Optional[bool] -
If multi-step is enabled and num_outputs is 1, this value
indicates if this outputs belongs to the first step in the
multi-step.
If multi-step is enabled and num_outputs > 1, this value
must be None, as num_outputs > 1 indicates that outputs from
all the steps in multi-step are submitted in a single burst.
When multi-step is disabled, this value is always True.
"""

assert seq_group_meta.is_prompt

token_chunk_size = seq_group_meta.token_chunk_size

if num_outputs == 1:
assert is_first_step_output is not None

if seq_group_meta.state.num_steps == 1:
assert is_first_step_output is True
seq_group.update_num_computed_tokens(token_chunk_size)
return

# multi-step prefill is only supported when multi-step is
# enabled with chunked prefill
assert self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled
if is_first_step_output is True:
# This sequence is a prompt during the first step only.
seq_group.update_num_computed_tokens(token_chunk_size)
return

assert is_first_step_output is None

# multi-step prefill is only supported when multi-step is
# enabled with chunked prefill. Outputs from all the steps are
# submitted in a single burst.
assert self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled
assert num_outputs == seq_group_meta.state.num_steps, \
f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
# This sequence is a prompt during the first step only.
seq_group.update_num_computed_tokens(token_chunk_size)

now = time.time()

if len(ctx.output_queue) == 0:
Expand Down Expand Up @@ -1137,7 +1118,7 @@ def update_prefill_num_computed_tokens(
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

seq_group = scheduled_seq_group.seq_group
seq_group: SequenceGroup = scheduled_seq_group.seq_group

if seq_group.is_finished():
finished_before.append(i)
Expand All @@ -1148,14 +1129,14 @@ def update_prefill_num_computed_tokens(
else:
output = [outputs_by_sequence_group[0][i]]

if not is_async and seq_group_meta.is_prompt:
# Updates for all decodes happen when we actually append the
# token ids to the seq in process_outputs.
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
len(output),
is_first_step_output)
elif not is_async:
seq_group.update_num_computed_tokens(1)
if not is_async:
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)

if outputs:
for o in outputs:
Expand All @@ -1179,16 +1160,8 @@ def update_prefill_num_computed_tokens(
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
output_token_num = self.output_processor.process_outputs(
self.output_processor.process_outputs(
seq_group, output, is_async)
if self.speculative_config:
# We -1 here because we always
# (w/o speculative decoding) add the number of
# computed tokens by one in the decoding phase.
# Therefore, we remove that one token that
# is already added.
seq_group.update_num_computed_tokens(output_token_num -
1)

if seq_group.is_finished():
finished_now.append(i)
Expand Down Expand Up @@ -1297,20 +1270,15 @@ def _advance_to_next_step(
if seq_group.is_finished():
continue

if seq_group_metadata.is_prompt:
if self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled:
# Prompts are scheduled in multi-step only when
# chunking is enabled. These prompts turn into
# decodes after the very first step. Therefore,
# we skip the update to the num_computed_tokens
# here.
seq_group.update_num_computed_tokens(1)
else:
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
seq_group.update_num_computed_tokens(1)
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)

if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
Expand All @@ -1320,7 +1288,15 @@ def _advance_to_next_step(

assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)

if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens(
) == 0
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_prefill_append:
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)

def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
Expand Down
8 changes: 2 additions & 6 deletions vllm/engine/output_processor/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from typing import Callable, List

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
Expand Down Expand Up @@ -58,14 +58,10 @@ def create_output_processor(
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool) -> Optional[int]:
is_async: bool) -> None:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
Return the number of new tokens generated in the sequence group.
The returned value is optional because it is only used for
speculative decoding mqa scorer.
"""
pass

Expand Down
Loading

0 comments on commit 9db54b6

Please sign in to comment.