Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core]: Option To Use Prompt Token Ids Inside Logits Processor #4985

Merged
merged 15 commits into from
May 23, 2024
Prev Previous commit
Next Next commit
more clear code
kezouke committed May 22, 2024
commit 5ed4398e82fa9a4c07db1f9e56a813bed9096137
32 changes: 13 additions & 19 deletions vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
@@ -95,31 +95,25 @@ def _apply_logits_processors(
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
use_prompt_tokens_seq = \
sampling_params.use_prompt_tokens
use_prompt_tokens_seq = sampling_params.use_prompt_tokens
if logits_processors:
found_logits_processors = True
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]

token_ids = seq_group.seq_data[seq_id].output_token_ids

token_ids_seq = []
for use_prompt_tokens in use_prompt_tokens_seq:
if use_prompt_tokens:
# The i-th logit processor needs prompt token IDs
token_ids_seq.append(
seq_group.seq_data[seq_id].prompt_token_ids +
token_ids)
else:
# The i-th logit processor needs only generated
# token IDs
token_ids_seq.append(token_ids)

for l_p_idx, logits_processor in enumerate(logits_processors):
logits_row = logits_processor(token_ids_seq[l_p_idx],
logits_row)
generated_token_ids = seq_group.seq_data[
seq_id].output_token_ids
prompt_token_ids = seq_group.seq_data[seq_id].prompt_token_ids
token_ids_seq = [
prompt_token_ids + generated_token_ids
if use_prompt_tokens else generated_token_ids
for use_prompt_tokens in use_prompt_tokens_seq
]

for logits_processor, token_ids in zip(logits_processors,
token_ids_seq):
logits_row = logits_processor(token_ids, logits_row)

logits[logits_row_idx] = logits_row

13 changes: 6 additions & 7 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
@@ -96,8 +96,8 @@ class SamplingParams:
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
use_prompt_tokens: Boolean list of whether to add
prompt tokens to the 'token_ids' argument of the logit processor.
use_prompt_tokens: List of booleans indicating whether to include prompt
token IDs in the logits processor functions.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
@@ -180,14 +180,13 @@ def __init__(
else:
self.output_text_buffer_length = 0

if logits_processors and use_prompt_tokens is None:
if logits_processors is not None and use_prompt_tokens is None:
assert self.logits_processors is not None
# Do not use prompt tokens in any logit processor
self.use_prompt_tokens: Optional[List[bool]] \
= [False] * len(self.logits_processors)
else:
self.use_prompt_tokens = \
use_prompt_tokens
self.use_prompt_tokens = use_prompt_tokens

self._verify_args()
if self.use_beam_search:
@@ -260,8 +259,8 @@ def _verify_args(self) -> None:
assert self.use_prompt_tokens is not None
if len(self.logits_processors) != \
len(self.use_prompt_tokens):
raise ValueError("logits_processors_use_prompt_tokens must"
" be the same length as logits_processors")
raise ValueError("use_prompt_tokens must be the "
"same length as logits_processors")

def _verify_beam_search(self) -> None:
if self.best_of == 1: