[Bugfix] Last token measurement fix (#11376)
Signed-off-by: rajveerb <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
rajveerb and ywang96 authored Dec 28, 2024
1 parent df04dff · commit b5cbe8e
Showing 2 changed files with 20 additions and 12 deletions.
8 changes: 6 additions & 2 deletions vllm/engine/llm_engine.py
@@ -1124,6 +1124,8 @@ def _process_model_outputs(self,
 
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
+            if not seq_group.is_prefill():
+                seq_group.set_last_token_time(now)
             request_output = RequestOutputFactory.create(
                 seq_group,
                 self.seq_id_to_seq_group,
@@ -1166,6 +1168,8 @@ def _process_model_outputs(self,
 
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
+            if not seq_group.is_prefill():
+                seq_group.set_last_token_time(now)
             request_output = RequestOutputFactory.create(
                 seq_group,
                 self.seq_id_to_seq_group,
@@ -1686,15 +1690,15 @@ def _get_stats(self,
                 # If the seq_group just finished the prefill state
                 # get TTFT.
                 if not seq_group.is_prefill():
-                    latency = seq_group.get_last_latency(now)
+                    latency = seq_group.get_last_token_latency()
                     time_to_first_tokens_iter.append(latency)
 
                     # One generation token per finished prefill.
                     num_generation_tokens_from_prefill_groups += (
                         seq_group.num_seqs())
             else:
                 # TPOTs.
-                latency = seq_group.get_last_latency(now)
+                latency = seq_group.get_last_token_latency()
                 time_per_output_tokens_iter.append(latency)
                 if seq_group.state.current_step == 0:
                     # For async_output_proc, the do_log_stats()
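Why the engine-side change matters: before this commit, _get_stats() obtained per-token latency via get_last_latency(now), a getter that also reset metrics.last_token_time as a side effect, so the measured interval depended on when (and whether) stats logging ran. The hunks above instead stamp the token time in _process_model_outputs, once per decode step. Below is a minimal, self-contained sketch of the old failure mode; OldStyleMetrics and its timings are simplified stand-ins for illustration, not the actual vLLM classes.

import time


class OldStyleMetrics:
    """Simplified stand-in for the pre-fix read-and-reset getter."""

    def __init__(self) -> None:
        self.last_token_time = time.monotonic()

    def get_last_latency(self, now: float) -> float:
        latency = now - self.last_token_time
        self.last_token_time = now  # timestamp reset hidden inside a getter
        return latency


metrics = OldStyleMetrics()
time.sleep(0.05)  # decode step 1; suppose do_log_stats() was skipped here
time.sleep(0.05)  # decode step 2; stats finally run now
# Reports ~0.10s as the latency of "one token" because the timestamp was
# never advanced after step 1, so per-token latency is over-counted.
print(f"latency: {metrics.get_last_latency(time.monotonic()):.2f}s")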
24 changes: 14 additions & 10 deletions vllm/sequence.py
@@ -667,6 +667,7 @@ def __init__(
                                       first_scheduled_time=None,
                                       first_token_time=None,
                                       time_in_queue=None)
+        self.last_token_latency = 0.0
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
@@ -762,18 +763,21 @@ def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
         assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
         self.init_multi_step(num_steps=num_lookahead_slots + 1)
 
-    def get_last_latency(self, now: float) -> float:
+    def set_last_token_time(self, now: float) -> None:
         """Sets the last token time for Request level timings."""
-        # If still in prefill phase, raise Error.
-        if self.is_prefill():
-            raise ValueError(
-                "seq_group.get_last_latency() should not be called "
-                "if the seq_group is in prefill phase.")
-
-        # Otherwise return token latency.
-        latency = now - self.metrics.last_token_time
+        # If still in prefill phase, assertion fails.
+        assert not self.is_prefill(), (
+            "seq_group.set_last_token_time() should not be called "
+            "if the seq_group is in prefill phase.")
+        self.last_token_latency = now - self.metrics.last_token_time
         self.metrics.last_token_time = now
-        return latency
+
+    def get_last_token_latency(self) -> float:
+        """Returns the latency of the last token."""
+        assert not self.is_prefill(), (
+            "seq_group.get_last_token_latency() should not be called "
+            "if the seq_group is in prefill phase.")
+        return self.last_token_latency
 
     def maybe_set_first_token_time(self, time: float) -> None:
         """Sets the first token time for Request level timings."""
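The sequence.py refactor splits the old combined method into a setter that computes and stores the latency (called once per decode step by the engine) and a side-effect-free getter for stats logging. Here is a minimal sketch of the new pattern; NewStyleMetrics is a simplified stand-in, not the actual SequenceGroup class.

import time


class NewStyleMetrics:
    """Simplified stand-in for the post-fix setter/getter split."""

    def __init__(self) -> None:
        self.last_token_time = time.monotonic()
        self.last_token_latency = 0.0

    def set_last_token_time(self, now: float) -> None:
        # Called once per decoded token by the output-processing loop,
        # so the measured interval always spans exactly one step.
        self.last_token_latency = now - self.last_token_time
        self.last_token_time = now

    def get_last_token_latency(self) -> float:
        # Pure read: stats logging can run late, repeatedly, or not at
        # all without distorting the measurement.
        return self.last_token_latency


metrics = NewStyleMetrics()
time.sleep(0.05)
metrics.set_last_token_time(time.monotonic())  # a token was produced
print(f"{metrics.get_last_token_latency():.2f}s")  # ~0.05s
print(f"{metrics.get_last_token_latency():.2f}s")  # same value, no reset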
