From 8136fa4561d201012bf1ed02c8f5200b918bc554 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 26 Nov 2024 23:40:04 +0000 Subject: [PATCH 01/39] Add options for min_tokens/repetition etc penalties to V1 sampler Signed-off-by: Sourashis Roy --- vllm/v1/sample/metadata.py | 5 ++++ vllm/v1/worker/gpu_model_runner.py | 48 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9ef36f2e6b212..18501fb9ae5f2 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -19,3 +19,8 @@ class SamplingMetadata: generators: Dict[int, torch.Generator] max_num_logprobs: int + + min_tokens: int = 0 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + repetition_penalty: float = 1.0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 02f9498142bb7..be33dd71c7097 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -686,6 +686,38 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() + self.presence_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = \ + self.presence_penalties_cpu_tensor.numpy() + + self.frequency_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + + self.repetition_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu = \ + self.repetition_penalties_cpu_tensor.numpy() + # req_index -> generator self.generators: Dict[int, torch.Generator] = {} @@ -732,6 +764,22 @@ def add_request( if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) + self.presence_penalties_cpu[req_index] = \ + sampling_params.presence_penalty + if sampling_params.presence_penalty > 0: + self.top_k_reqs.add(req_id) + + + self.frequency_penalties_cpu[req_index] = \ + sampling_params.frequency_penalty + if sampling_params.frequency_penalty > 0: + self.top_k_reqs.add(req_id) + + self.repetition_penalties_cpu[req_index] = \ + sampling_params.repetition_penalty + if sampling_params.repetition_penalty > 0: + self.top_k_reqs.add(req_id) + self.generators[req_index] = request.generator num_logprobs = sampling_params.logprobs From 06d3247a3c5192f7ed945b6c8964d5b7b73a85a1 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sun, 1 Dec 2024 07:14:21 +0000 Subject: [PATCH 02/39] Fixes --- vllm/utils.py | 40 ++++++++++++ vllm/v1/sample/metadata.py | 13 ++-- vllm/v1/sample/sampler.py | 89 ++++++++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 100 +++++++++++++++-------------- 4 files changed, 186 insertions(+), 56 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index bec876d983701..ecca7001eaef4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1639,3 +1639,43 @@ def resolve_obj_by_qualname(qualname: str) -> Any: module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, obj_name) + + +def 
get_token_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def apply_sampling_penalties( + logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) + repetition_penalties[~(prompt_mask | output_mask)] = 1.0 + logits = torch.where(logits > 0, logits / repetition_penalties, + logits * repetition_penalties) + + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + return logits diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 18501fb9ae5f2..529d09f2ecadd 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict +from typing import Dict, List import torch @@ -20,7 +20,10 @@ class SamplingMetadata: max_num_logprobs: int - min_tokens: int = 0 - presence_penalty: float = 0.0 - frequency_penalty: float = 0.0 - repetition_penalty: float = 1.0 + output_token_ids: List[List[int]] + prompt_token_ids: List[List[int]] + frequency_penalties: List[float] + presence_penalties: List[float] + repetition_penalties: List[float] + min_tokens: List[int] + stop_token_ids: List[List[int]] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 927f274541c4d..e44aae4a51991 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,9 +1,10 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict +from typing import Dict, List, Tuple import torch import torch.nn as nn +from vllm.utils import apply_sampling_penalties, make_tensor_with_pad from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -19,7 +20,14 @@ def forward( ) -> SamplerOutput: logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) - + _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, + sampling_metadata.stop_token_ids, + sampling_metadata.min_tokens) + _apply_penalties(logits, sampling_metadata.prompt_token_ids, + sampling_metadata.output_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties) probs = self.get_probs(logits) sampled = self.sample(probs, sampling_metadata) # Use int32 to reduce the tensor size. @@ -156,3 +164,80 @@ def _apply_top_k_top_p( # Re-sort the probabilities. 
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits + + +def _apply_min_token_penalties(logits: torch.Tensor, + output_token_ids: List[List[int]], + stop_token_ids: List[List[int]], + min_tokens: List[int]): + # Compute min_tokens_logits_to_penalize + min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] + for index, min_token in enumerate(min_tokens): + if (min_token > 0 and len(output_token_ids[index]) < min_token): + for stop_token_id in stop_token_ids: + min_tokens_logits_to_penalize.append((index, stop_token_id)) + if min_tokens_logits_to_penalize: + logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") + + +def _apply_penalties(logits: torch.Tensor, prompt_token_ids: List[List[int]], + output_token_ids: List[List[int]], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float]): + apply_penalties = any(p != 0.0 for p in presence_penalties) or any( + f != 0.0 + for f in frequency_penalties) or any(r != 1.0 + for r in repetition_penalties) + if apply_penalties: + # Convert to tensors + _, vocab_size = logits.shape + (prompt_tokens_t, output_tokens_t, frequency_penalties_t, + presence_penalties_t, repetition_penalties_t) = \ + _convert_to_tensors( + prompt_token_ids, output_token_ids, frequency_penalties, + presence_penalties, repetition_penalties, vocab_size, + logits.device) + return apply_sampling_penalties(logits, prompt_tokens_t, + output_tokens_t, presence_penalties_t, + frequency_penalties_t, + repetition_penalties_t) + + +def _convert_to_tensors(prompt_token_ids: List[List[int]], + output_token_ids: List[List[int]], + frequency_penalties: List[float], + presence_penalties: List[float], + repetition_penalties: List[float], vocab_size: int, + device: torch.device) -> Tuple[torch.Tensor, ...]: + prompt_tokens_tensor = make_tensor_with_pad( + prompt_token_ids, + vocab_size, + device=device, + dtype=torch.int64, + ) + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + vocab_size, + device=device, + dtype=torch.int64, + ) + frequency_penalties_tensor = torch.tensor( + frequency_penalties, + device=device, + dtype=torch.float, + ) + presence_penalties_tensor = torch.tensor( + presence_penalties, + device=device, + dtype=torch.float, + ) + repetition_penalties_tensor = torch.tensor( + repetition_penalties, + device=device, + dtype=torch.float, + ) + + return (prompt_tokens_tensor, output_tokens_tensor, + frequency_penalties_tensor, presence_penalties_tensor, + repetition_penalties_tensor) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index be33dd71c7097..f4c1dc7657628 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -335,7 +335,8 @@ def _prepare_sampling( or scheduler_output.scheduled_resumed_reqs): skip_copy = False # Create the sampling metadata. 
- sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + sampling_metadata = self.input_batch.make_sampling_metadata( + self.requests, skip_copy) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -609,7 +610,6 @@ class CachedRequestState: mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: List[int] num_computed_tokens: int output_token_ids: List[int] @@ -618,6 +618,22 @@ class CachedRequestState: def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) + @property + def stop_token_ids(self) -> Optional[List[int]]: + return self.sampling_params.stop_token_ids + + @property + def prompt_tokens_mask(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + @property + def output_tokens_mask(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + @property + def output_tokens_bin_counts(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + class InputBatch: @@ -686,37 +702,8 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = \ - self.presence_penalties_cpu_tensor.numpy() - - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() - - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self.prompt_masks = Dict[int, torch.Tensor] + self.output_masks = Dict[int, torch.Tensor] # req_index -> generator self.generators: Dict[int, torch.Generator] = {} @@ -764,22 +751,6 @@ def add_request( if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) - self.presence_penalties_cpu[req_index] = \ - sampling_params.presence_penalty - if sampling_params.presence_penalty > 0: - self.top_k_reqs.add(req_id) - - - self.frequency_penalties_cpu[req_index] = \ - sampling_params.frequency_penalty - if sampling_params.frequency_penalty > 0: - self.top_k_reqs.add(req_id) - - self.repetition_penalties_cpu[req_index] = \ - sampling_params.repetition_penalty - if sampling_params.repetition_penalty > 0: - self.top_k_reqs.add(req_id) - self.generators[req_index] = request.generator num_logprobs = sampling_params.logprobs @@ -850,6 +821,7 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator @@ -859,6 +831,7 @@ def condense(self, empty_req_indices: List[int]) -> None: def make_sampling_metadata( self, + requests: Dict[str, CachedRequestState], skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: @@ 
-868,6 +841,28 @@ def make_sampling_metadata( self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_k[:self.num_reqs].copy_( self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + + output_token_ids: List[List[int]] = [] + prompt_token_ids: List[List[int]] = [] + frequency_penalties: List[float] = [] + presence_penalties: List[float] = [] + repetition_penalties: List[float] = [] + min_tokens: List[int] = [] + stop_token_ids: List[List[int]] = [] + + for req_id in self.req_ids[:self.num_reqs]: + assert req_id is not None + request = requests[req_id] + output_token_ids.append(request.output_token_ids) + prompt_token_ids.append(request.prompt_token_ids) + frequency_penalties.append( + request.sampling_params.frequency_penalty) + presence_penalties.append(request.sampling_params.presence_penalty) + repetition_penalties.append( + request.sampling_params.repetition_penalty) + min_tokens.append(request.sampling_params.min_tokens) + stop_token_ids.append(request.sampling_params.stop_token_ids) + return SamplingMetadata( temperature=self.temperature[:self.num_reqs], all_greedy=self.all_greedy, @@ -878,6 +873,13 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + frequency_penalties=frequency_penalties, + presence_penalties=presence_penalties, + repetition_penalties=repetition_penalties, + min_tokens=min_tokens, + stop_token_ids=stop_token_ids, ) @property From ca0313a114e98bc8a60e262ec515e30f44d6d8c5 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 2 Dec 2024 17:06:54 +0000 Subject: [PATCH 03/39] Add tests --- vllm/model_executor/layers/sampler.py | 52 ++++----------------------- vllm/utils.py | 6 ++-- vllm/v1/sample/metadata.py | 4 +-- vllm/v1/sample/sampler.py | 35 +++++++++++------- vllm/v1/worker/gpu_model_runner.py | 4 +-- 5 files changed, 35 insertions(+), 66 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c10efefea5471..5c3ccd74ad8f8 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -19,6 +19,7 @@ CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm.utils import apply_sampling_penalties if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -258,11 +259,12 @@ def forward( # Apply presence and frequency penalties. if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + logits = apply_sampling_penalties( + logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. @@ -336,23 +338,6 @@ def _should_modify_greedy_probs_inplace(self) -> bool: return self.should_modify_greedy_probs_inplace -def _get_bin_counts_and_mask( - tokens: torch.Tensor, - vocab_size: int, - num_seqs: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - # Compute the bin counts for the tokens. - # vocab_size + 1 for padding. 
- bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) - bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) - bin_counts = bin_counts[:, :vocab_size] - mask = bin_counts > 0 - - return bin_counts, mask - - def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -400,29 +385,6 @@ def _apply_min_tokens_penalty( return logits -def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: - num_seqs, vocab_size = logits.shape - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, - num_seqs) - output_bin_counts, output_mask = _get_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 - logits = torch.where(logits > 0, logits / repetition_penalties, - logits * repetition_penalties) - - # We follow the definition in OpenAI API. - # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask - return logits - - def _apply_top_k_top_p( logits: torch.Tensor, p: torch.Tensor, diff --git a/vllm/utils.py b/vllm/utils.py index ecca7001eaef4..14265ff349a9e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1668,12 +1668,10 @@ def apply_sampling_penalties( vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 - logits = torch.where(logits > 0, logits / repetition_penalties, - logits * repetition_penalties) - + logits[logits > 0] /= repetition_penalties[logits > 0] + logits[logits <= 0] *= repetition_penalties[logits <= 0] # We follow the definition in OpenAI API. 
# Refer to https://platform.openai.com/docs/api-reference/parameter-details logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 529d09f2ecadd..0503012088928 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Set import torch @@ -26,4 +26,4 @@ class SamplingMetadata: presence_penalties: List[float] repetition_penalties: List[float] min_tokens: List[int] - stop_token_ids: List[List[int]] + stop_token_ids: List[Set[int]] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index e44aae4a51991..c096c1ff3abb2 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,5 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Set import torch import torch.nn as nn @@ -18,16 +18,17 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - logits = self.apply_temperature(logits, sampling_metadata.temperature) - logits = self.apply_top_k_top_p(logits, sampling_metadata) - _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, + _apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, sampling_metadata.stop_token_ids, sampling_metadata.min_tokens) _apply_penalties(logits, sampling_metadata.prompt_token_ids, sampling_metadata.output_token_ids, sampling_metadata.presence_penalties, sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties) + sampling_metadata.repetition_penalties) + logits = self.apply_temperature(logits, sampling_metadata.temperature) + logits = self.apply_top_k_top_p(logits, sampling_metadata) probs = self.get_probs(logits) sampled = self.sample(probs, sampling_metadata) # Use int32 to reduce the tensor size. @@ -168,29 +169,34 @@ def _apply_top_k_top_p( def _apply_min_token_penalties(logits: torch.Tensor, output_token_ids: List[List[int]], - stop_token_ids: List[List[int]], - min_tokens: List[int]): - # Compute min_tokens_logits_to_penalize + stop_token_ids: List[Set[int]], + min_tokens: List[int]) -> torch.Tensor: + """ + Applies minimum token penalty by setting the logits of the stop tokens + to -inf. + """ min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] for index, min_token in enumerate(min_tokens): if (min_token > 0 and len(output_token_ids[index]) < min_token): - for stop_token_id in stop_token_ids: + for stop_token_id in stop_token_ids[index]: min_tokens_logits_to_penalize.append((index, stop_token_id)) if min_tokens_logits_to_penalize: logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") - + return logits def _apply_penalties(logits: torch.Tensor, prompt_token_ids: List[List[int]], output_token_ids: List[List[int]], presence_penalties: List[float], frequency_penalties: List[float], - repetition_penalties: List[float]): + repetition_penalties: List[float]) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. 
+ """ apply_penalties = any(p != 0.0 for p in presence_penalties) or any( f != 0.0 for f in frequency_penalties) or any(r != 1.0 for r in repetition_penalties) if apply_penalties: - # Convert to tensors _, vocab_size = logits.shape (prompt_tokens_t, output_tokens_t, frequency_penalties_t, presence_penalties_t, repetition_penalties_t) = \ @@ -202,6 +208,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_token_ids: List[List[int]], output_tokens_t, presence_penalties_t, frequency_penalties_t, repetition_penalties_t) + return logits def _convert_to_tensors(prompt_token_ids: List[List[int]], @@ -210,6 +217,9 @@ def _convert_to_tensors(prompt_token_ids: List[List[int]], presence_penalties: List[float], repetition_penalties: List[float], vocab_size: int, device: torch.device) -> Tuple[torch.Tensor, ...]: + """ + Convert the different list data structures to tensors. + """ prompt_tokens_tensor = make_tensor_with_pad( prompt_token_ids, vocab_size, @@ -237,7 +247,6 @@ def _convert_to_tensors(prompt_token_ids: List[List[int]], device=device, dtype=torch.float, ) - return (prompt_tokens_tensor, output_tokens_tensor, frequency_penalties_tensor, presence_penalties_tensor, repetition_penalties_tensor) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f4c1dc7657628..51f8ac02e2f7f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -848,7 +848,7 @@ def make_sampling_metadata( presence_penalties: List[float] = [] repetition_penalties: List[float] = [] min_tokens: List[int] = [] - stop_token_ids: List[List[int]] = [] + stop_token_ids: List[set[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None @@ -861,7 +861,7 @@ def make_sampling_metadata( repetition_penalties.append( request.sampling_params.repetition_penalty) min_tokens.append(request.sampling_params.min_tokens) - stop_token_ids.append(request.sampling_params.stop_token_ids) + stop_token_ids.append(request.sampling_params.all_stop_token_ids) return SamplingMetadata( temperature=self.temperature[:self.num_reqs], From e19f99bedac3d9d39a4e0bfe77182f1d3d4d289b Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 2 Dec 2024 17:21:26 +0000 Subject: [PATCH 04/39] Add tests --- tests/v1/sample/test_sampler.py | 176 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 21 ---- 2 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 tests/v1/sample/test_sampler.py diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py new file mode 100644 index 0000000000000..96c64b857794c --- /dev/null +++ b/tests/v1/sample/test_sampler.py @@ -0,0 +1,176 @@ +import pytest +import torch +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler +from typing import List, Set, Tuple +import numpy as np + +VOCAB_SIZE = 1024 +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] +MAX_NUM_PROMPT_TOKENS = 64 + + +def _create_fake_logits( + batch_size: int, vocab_size:int +) -> torch.Tensor: + fake_logits = torch.full((batch_size, vocab_size), + 1e-2, + dtype=torch.float) + return fake_logits + +def _create_default_sampling_metadata( + num_output_tokens: int, batch_size: int, + vocab_size: int, +) -> SamplingMetadata: + output_token_ids:List[List[int]] = [] + prompt_token_ids:List[List[int]] = [] + for _ in range(batch_size): + output_token_ids.append( + np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + 
prompt_token_ids.append( + np.random.randint(0, vocab_size, + size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)).tolist()) + fake_sampling_metadata = SamplingMetadata( + temperature=torch.full((batch_size,), 0.0), + all_greedy=True, + all_random=False, + top_p=torch.empty(batch_size,), + top_k=torch.empty(batch_size,), + no_top_p=True, + no_top_k=True, + generators={}, + max_num_logprobs=VOCAB_SIZE, + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + frequency_penalties=[0.0 for _ in range(batch_size)], + presence_penalties=[0.0 for _ in range(batch_size)], + repetition_penalties=[1.0 for _ in range(batch_size)], + min_tokens=[], + stop_token_ids=[], + ) + return fake_sampling_metadata + +def _create_min_token_penalty_dataset( + num_output_tokens: int, + batch_size: int, + vocab_size: int, + batch_indices_for_min_token_penalty:List[int] +) -> Tuple[List[int], List[Set[int]]]: + stop_token_ids:List[Set[int]] = [] + min_tokens: List[int]=[] + for index in range(batch_size): + if index in batch_indices_for_min_token_penalty: + min_tokens.append( + np.random.randint(num_output_tokens + 1, 2 * num_output_tokens)) + stop_token_ids.append( + set(np.random.randint(0, vocab_size - 1) for _ in range( + np.random.randint(0, vocab_size)))) + + else: + min_tokens.append(np.random.randint(0, num_output_tokens)) + stop_token_ids.append(set()) + return (min_tokens, stop_token_ids) + +def _create_weighted_output_token_list( + batch_size: int, + vocab_size: int +) -> Tuple[List[List[int]], List[List[int]]]: + output_token_ids : List[List[int]] = [] + sorted_token_ids_in_output : List[List[int]] = [] + for _ in range(batch_size): + distinct_token_ids = np.random.choice(vocab_size, size=np.random.randint(1, 10), replace=False).tolist() + sorted_token_ids_in_output.append(distinct_token_ids) + output_token_ids_for_batch = [] + for index, token_id in enumerate(distinct_token_ids): + output_token_ids_for_batch.extend([token_id for _ in range(index+1) ]) + output_token_ids.append(output_token_ids_for_batch) + return (output_token_ids, sorted_token_ids_in_output) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +def test_sampler_min_tokens_penalty(device: str, batch_size: int): + """ + Tests that if the number of output tokens is less than + SamplingParams.min_tokens then we will set the logits for + the stop token ids to -inf. 
+ """ + torch.set_default_device(device) + NUM_OUTPUT_TOKENS = 20 + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata= _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) + batch_indices_for_min_token_penalty = np.random.randint( + 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() + min_tokens, stop_token_ids = _create_min_token_penalty_dataset( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty) + sampling_metadata.min_tokens = min_tokens + sampling_metadata.stop_token_ids = stop_token_ids + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + for vocab in range(VOCAB_SIZE): + logprob_index = torch.where( + sampler_output.logprob_token_ids[batch_idx] == vocab)[0].item() + if vocab in stop_token_ids[batch_idx]: + assert sampler_output.logprobs[batch_idx][logprob_index] == -float("inf") + else: + assert sampler_output.logprobs[batch_idx][logprob_index] != -float("inf") + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +def test_sampler_presence_penalty(device: str, batch_size: int): + torch.set_default_device(device) + NUM_OUTPUT_TOKENS = 20 + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata= _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) + output_token_ids = sampling_metadata.output_token_ids + sampling_metadata.presence_penalties = [2.0 for _ in range(batch_size)] + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + logprob_for_output_token = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] + logprob_for_non_output_token = sampler_output.logprobs[batch_idx][0] + assert logprob_for_non_output_token > logprob_for_output_token + for vocab in range(VOCAB_SIZE): + logprob_index = torch.where( + sampler_output.logprob_token_ids[batch_idx] == vocab)[0].item() + if vocab in output_token_ids[batch_idx]: + assert torch.isclose( + sampler_output.logprobs[batch_idx][logprob_index], + logprob_for_output_token) + else: + assert torch.isclose( + sampler_output.logprobs[batch_idx][logprob_index], + logprob_for_non_output_token) + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +def test_sampler_frequency_penalty(device: str, batch_size: int): + """ + Test to verify that if fre + """ + torch.set_default_device(device) + NUM_OUTPUT_TOKENS = 20 + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata= _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) + sampling_metadata.frequency_penalties = [2.0 for _ in range(batch_size)] + output_token_ids, sorted_token_ids_in_output = \ + _create_weighted_output_token_list(batch_size, VOCAB_SIZE) + sampling_metadata.output_token_ids=output_token_ids + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] + token_ids_in_output = sorted_token_ids_in_output[batch_idx] + assert not torch.isin( + logprobs_token_ids[ : -len(token_ids_in_output)], + torch.tensor(token_ids_in_output)).any(), "Some values in the tensor are in the list" + assert logprobs_token_ids[-len(token_ids_in_output):].tolist() == token_ids_in_output, \ + "The tensor values are not in the same order as the list!" 
+ + diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 51f8ac02e2f7f..30d73177a77c4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -618,23 +618,6 @@ class CachedRequestState: def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) - @property - def stop_token_ids(self) -> Optional[List[int]]: - return self.sampling_params.stop_token_ids - - @property - def prompt_tokens_mask(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - @property - def output_tokens_mask(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - @property - def output_tokens_bin_counts(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - class InputBatch: def __init__( @@ -702,9 +685,6 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() - self.prompt_masks = Dict[int, torch.Tensor] - self.output_masks = Dict[int, torch.Tensor] - # req_index -> generator self.generators: Dict[int, torch.Generator] = {} @@ -821,7 +801,6 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator From 40f4ce28c1cce5eef1dbae0b4b25a4b83c3fee1a Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 2 Dec 2024 17:23:36 +0000 Subject: [PATCH 05/39] Fix format --- vllm/v1/worker/gpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 30d73177a77c4..e981514400368 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -610,6 +610,7 @@ class CachedRequestState: mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams generator: Optional[torch.Generator] + block_ids: List[int] num_computed_tokens: int output_token_ids: List[int] @@ -618,6 +619,7 @@ class CachedRequestState: def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) + class InputBatch: def __init__( From d3e9bb708021397d7b12eee39b4151675b131a77 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 2 Dec 2024 18:48:12 +0000 Subject: [PATCH 06/39] Comments --- tests/v1/sample/test_sampler.py | 171 ++++++++++++++++++++++---------- vllm/utils.py | 3 + vllm/v1/sample/sampler.py | 21 ++-- 3 files changed, 132 insertions(+), 63 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 96c64b857794c..7c7b276e33d41 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -1,9 +1,11 @@ +from typing import List, Set, Tuple + +import numpy as np import pytest import torch + from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler -from typing import List, Set, Tuple -import numpy as np VOCAB_SIZE = 1024 CUDA_DEVICES = [ @@ -12,32 +14,32 @@ MAX_NUM_PROMPT_TOKENS = 64 -def _create_fake_logits( - batch_size: int, vocab_size:int -) -> torch.Tensor: - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=torch.float) +def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: + fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float) return fake_logits + def 
_create_default_sampling_metadata( - num_output_tokens: int, batch_size: int, - vocab_size: int, + num_output_tokens: int, + batch_size: int, + vocab_size: int, ) -> SamplingMetadata: - output_token_ids:List[List[int]] = [] - prompt_token_ids:List[List[int]] = [] + output_token_ids: List[List[int]] = [] + prompt_token_ids: List[List[int]] = [] for _ in range(batch_size): output_token_ids.append( np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) prompt_token_ids.append( - np.random.randint(0, vocab_size, - size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)).tolist()) + np.random.randint(0, + vocab_size, + size=np.random.randint( + 1, MAX_NUM_PROMPT_TOKENS)).tolist()) fake_sampling_metadata = SamplingMetadata( - temperature=torch.full((batch_size,), 0.0), + temperature=torch.full((batch_size, ), 0.0), all_greedy=True, all_random=False, - top_p=torch.empty(batch_size,), - top_k=torch.empty(batch_size,), + top_p=torch.empty(batch_size, ), + top_k=torch.empty(batch_size, ), no_top_p=True, no_top_k=True, generators={}, @@ -52,39 +54,63 @@ def _create_default_sampling_metadata( ) return fake_sampling_metadata + def _create_min_token_penalty_dataset( - num_output_tokens: int, - batch_size: int, - vocab_size: int, - batch_indices_for_min_token_penalty:List[int] + num_output_tokens: int, batch_size: int, vocab_size: int, + batch_indices_for_min_token_penalty: List[int] ) -> Tuple[List[int], List[Set[int]]]: - stop_token_ids:List[Set[int]] = [] - min_tokens: List[int]=[] + """ + + """ + stop_token_ids: List[Set[int]] = [] + min_tokens: List[int] = [] for index in range(batch_size): if index in batch_indices_for_min_token_penalty: min_tokens.append( - np.random.randint(num_output_tokens + 1, 2 * num_output_tokens)) + np.random.randint(num_output_tokens + 1, + 2 * num_output_tokens)) stop_token_ids.append( - set(np.random.randint(0, vocab_size - 1) for _ in range( - np.random.randint(0, vocab_size)))) + set( + np.random.randint(0, vocab_size - 1) + for _ in range(np.random.randint(0, vocab_size)))) else: min_tokens.append(np.random.randint(0, num_output_tokens)) stop_token_ids.append(set()) return (min_tokens, stop_token_ids) + def _create_weighted_output_token_list( - batch_size: int, - vocab_size: int -) -> Tuple[List[List[int]], List[List[int]]]: - output_token_ids : List[List[int]] = [] - sorted_token_ids_in_output : List[List[int]] = [] + batch_size: int, + vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]: + """ + Creates an output token list where each token occurs a distinct + number of times. + + For each batch, a random subset of token IDs is selected from the + vocabulary. The selected tokens are then added to the output token + list, each with a different frequency. + + Returns: + Tuple[List[List[int]], List[List[int]]]: + - The first element is the output token list, where each sublist + corresponds to a batch and contains tokens with weighted + frequencies. + - The second element is a list of distinct token IDs for each + batch, ordered by their frequency in the corresponding output + list. 
+ """ + output_token_ids: List[List[int]] = [] + sorted_token_ids_in_output: List[List[int]] = [] for _ in range(batch_size): - distinct_token_ids = np.random.choice(vocab_size, size=np.random.randint(1, 10), replace=False).tolist() + distinct_token_ids = np.random.choice(vocab_size, + size=np.random.randint(1, 10), + replace=False).tolist() sorted_token_ids_in_output.append(distinct_token_ids) output_token_ids_for_batch = [] for index, token_id in enumerate(distinct_token_ids): - output_token_ids_for_batch.extend([token_id for _ in range(index+1) ]) + output_token_ids_for_batch.extend( + [token_id for _ in range(index + 1)]) output_token_ids.append(output_token_ids_for_batch) return (output_token_ids, sorted_token_ids_in_output) @@ -100,77 +126,120 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): torch.set_default_device(device) NUM_OUTPUT_TOKENS = 20 fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) - sampling_metadata= _create_default_sampling_metadata( + sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) batch_indices_for_min_token_penalty = np.random.randint( - 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() + 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() min_tokens, stop_token_ids = _create_min_token_penalty_dataset( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, + batch_indices_for_min_token_penalty) sampling_metadata.min_tokens = min_tokens sampling_metadata.stop_token_ids = stop_token_ids sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): for vocab in range(VOCAB_SIZE): + # Verify that the logprobs for stop token ids is set + # to -inf. logprob_index = torch.where( - sampler_output.logprob_token_ids[batch_idx] == vocab)[0].item() + sampler_output.logprob_token_ids[batch_idx] == + vocab)[0].item() if vocab in stop_token_ids[batch_idx]: - assert sampler_output.logprobs[batch_idx][logprob_index] == -float("inf") + assert sampler_output.logprobs[batch_idx][ + logprob_index] == -float("inf") else: - assert sampler_output.logprobs[batch_idx][logprob_index] != -float("inf") + assert sampler_output.logprobs[batch_idx][ + logprob_index] != -float("inf") + @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) def test_sampler_presence_penalty(device: str, batch_size: int): + """ + Test to verify that if presence penalty is enabled then tokens + which are already present in the output are penalized vs tokens + which are not yet present in the output. + """ torch.set_default_device(device) NUM_OUTPUT_TOKENS = 20 + # Create fake logits where each token is assigned the same + # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) - sampling_metadata= _create_default_sampling_metadata( + sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) output_token_ids = sampling_metadata.output_token_ids sampling_metadata.presence_penalties = [2.0 for _ in range(batch_size)] sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): - logprob_for_output_token = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] + # The logprobs in the SamplerOutput are arranged in descending order. 
+ # Since all tokens initially have the same logprobs, the non-penalized + # tokens (i.e., token IDs not present in the output) will appear at + # the beginning, while the penalized tokens (i.e., token IDs present in + # the output) will appear at the end of the list. + logprob_for_output_token = sampler_output.logprobs[batch_idx][ + VOCAB_SIZE - 1] logprob_for_non_output_token = sampler_output.logprobs[batch_idx][0] assert logprob_for_non_output_token > logprob_for_output_token for vocab in range(VOCAB_SIZE): logprob_index = torch.where( - sampler_output.logprob_token_ids[batch_idx] == vocab)[0].item() + sampler_output.logprob_token_ids[batch_idx] == + vocab)[0].item() if vocab in output_token_ids[batch_idx]: + # This token is present in the list of already output tokens. + # Hence it must have been penalized by the presence penalty. + # Verify that the logprob of this token is same as that + # expected for penalized tokens. assert torch.isclose( sampler_output.logprobs[batch_idx][logprob_index], logprob_for_output_token) else: + # This token is not present in the list of already output + # tokens. Hence it has not been penalized. + # Verify that the logprob of this token is same as that + # expected for non-penalized tokens. assert torch.isclose( sampler_output.logprobs[batch_idx][logprob_index], logprob_for_non_output_token) + @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) def test_sampler_frequency_penalty(device: str, batch_size: int): """ - Test to verify that if fre + Test to verify that if frequency penalty is enabled then tokens with + higher frequency are penalized more than those with lower frequency. """ torch.set_default_device(device) NUM_OUTPUT_TOKENS = 20 + # Create fake logits where each token is assigned the same + # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) - sampling_metadata= _create_default_sampling_metadata( + sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) sampling_metadata.frequency_penalties = [2.0 for _ in range(batch_size)] output_token_ids, sorted_token_ids_in_output = \ _create_weighted_output_token_list(batch_size, VOCAB_SIZE) - sampling_metadata.output_token_ids=output_token_ids + sampling_metadata.output_token_ids = output_token_ids sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] token_ids_in_output = sorted_token_ids_in_output[batch_idx] - assert not torch.isin( - logprobs_token_ids[ : -len(token_ids_in_output)], - torch.tensor(token_ids_in_output)).any(), "Some values in the tensor are in the list" - assert logprobs_token_ids[-len(token_ids_in_output):].tolist() == token_ids_in_output, \ - "The tensor values are not in the same order as the list!" - + # The logprobs in the SamplerOutput are arranged in descending order. + # Initially, all tokens have the same logprob, so their order reflects + # their frequency in the existing output. Tokens that occur least + # frequently in the current output appear at the beginning of the list, + # while tokens that occur most frequently are placed at the end. + # Verify that tokens not present in the current output are ranked + # higher (i.e., appear earlier) than tokens already present in the + # output. 
+ assert not torch.isin(logprobs_token_ids[:-len(token_ids_in_output)], + torch.tensor(token_ids_in_output)).any( + ), "Some values in the tensor are in the list" + # Verify that the logprobs of tokens already present in the output + # decrease as their frequency of occurrence increases. + assert logprobs_token_ids[-len(token_ids_in_output):].tolist() \ + == token_ids_in_output, \ + "The tensor values are not in the same order as the list!" diff --git a/vllm/utils.py b/vllm/utils.py index 14265ff349a9e..4ebe9b89e9562 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1663,6 +1663,9 @@ def apply_sampling_penalties( output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, repetition_penalties: torch.Tensor) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. + """ num_seqs, vocab_size = logits.shape _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, num_seqs) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index c096c1ff3abb2..f6a6feaad34b3 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,5 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Tuple, Set +from typing import Dict, List, Set, Tuple import torch import torch.nn as nn @@ -18,15 +18,14 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - _apply_min_token_penalties(logits, - sampling_metadata.output_token_ids, + _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, sampling_metadata.stop_token_ids, sampling_metadata.min_tokens) _apply_penalties(logits, sampling_metadata.prompt_token_ids, sampling_metadata.output_token_ids, sampling_metadata.presence_penalties, sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties) + sampling_metadata.repetition_penalties) logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) probs = self.get_probs(logits) @@ -170,7 +169,7 @@ def _apply_top_k_top_p( def _apply_min_token_penalties(logits: torch.Tensor, output_token_ids: List[List[int]], stop_token_ids: List[Set[int]], - min_tokens: List[int]) -> torch.Tensor: + min_tokens: List[int]): """ Applies minimum token penalty by setting the logits of the stop tokens to -inf. @@ -182,13 +181,13 @@ def _apply_min_token_penalties(logits: torch.Tensor, min_tokens_logits_to_penalize.append((index, stop_token_id)) if min_tokens_logits_to_penalize: logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") - return logits + def _apply_penalties(logits: torch.Tensor, prompt_token_ids: List[List[int]], output_token_ids: List[List[int]], presence_penalties: List[float], frequency_penalties: List[float], - repetition_penalties: List[float]) -> torch.Tensor: + repetition_penalties: List[float]): """ Applies presence, frequency and repetition penalties to the logits. 
""" @@ -204,11 +203,9 @@ def _apply_penalties(logits: torch.Tensor, prompt_token_ids: List[List[int]], prompt_token_ids, output_token_ids, frequency_penalties, presence_penalties, repetition_penalties, vocab_size, logits.device) - return apply_sampling_penalties(logits, prompt_tokens_t, - output_tokens_t, presence_penalties_t, - frequency_penalties_t, - repetition_penalties_t) - return logits + apply_sampling_penalties(logits, prompt_tokens_t, output_tokens_t, + presence_penalties_t, frequency_penalties_t, + repetition_penalties_t) def _convert_to_tensors(prompt_token_ids: List[List[int]], From e3468fecbceb4dac7318757b933e41ac875bd5c2 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 2 Dec 2024 21:38:16 +0000 Subject: [PATCH 07/39] Tests --- tests/v1/engine/test_engine_core.py | 30 +++++++++++++++++++++++++++ tests/v1/sample/test_sampler.py | 32 ++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index b3692b594326a..e44f92b67ab7d 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -138,3 +138,33 @@ def test_engine_core(monkeypatch): engine_core.abort_requests([req2.request_id, req0.request_id]) assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 0 + + +def test_engine_core_advanced_sampling(monkeypatch): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + """Setup the EngineCore.""" + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config() + executor_class = AsyncLLM._get_executor_cls(vllm_config) + + engine_core = EngineCore(vllm_config=vllm_config, + executor_class=executor_class, + usage_context=UsageContext.UNKNOWN_CONTEXT) + """Test basic request lifecycle.""" + # First request. + request: EngineCoreRequest = make_request() + request.sampling_params = SamplingParams(min_tokens=4, + presence_penalty=1.0, + frequency_penalty=1.0, + repetition_penalty=0.1, + stop_token_ids=[1001, 1002]) + engine_core.add_request(make_request()) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 + # Loop through until they are all done. + while len(engine_core.step()) > 0: + pass + + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 7c7b276e33d41..090e94a410e0c 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -8,6 +8,7 @@ from vllm.v1.sample.sampler import Sampler VOCAB_SIZE = 1024 +NUM_OUTPUT_TOKENS = 20 CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] @@ -124,7 +125,6 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): the stop token ids to -inf. """ torch.set_default_device(device) - NUM_OUTPUT_TOKENS = 20 fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) @@ -161,7 +161,6 @@ def test_sampler_presence_penalty(device: str, batch_size: int): which are not yet present in the output. """ torch.set_default_device(device) - NUM_OUTPUT_TOKENS = 20 # Create fake logits where each token is assigned the same # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) @@ -211,7 +210,6 @@ def test_sampler_frequency_penalty(device: str, batch_size: int): higher frequency are penalized more than those with lower frequency. """ torch.set_default_device(device) - NUM_OUTPUT_TOKENS = 20 # Create fake logits where each token is assigned the same # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) @@ -243,3 +241,31 @@ def test_sampler_frequency_penalty(device: str, batch_size: int): assert logprobs_token_ids[-len(token_ids_in_output):].tolist() \ == token_ids_in_output, \ "The tensor values are not in the same order as the list!" + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +def test_sampler_repetition_penalty(device: str, batch_size: int): + """ + Test to verify that if frequency penalty is enabled then tokens with + higher frequency are penalized more than those with lower frequency. + """ + torch.set_default_device(device) + # Create fake logits where each token is assigned the same + # logit value. + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) + sampling_metadata.repetition_penalties = [2.0 for _ in range(batch_size)] + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] + assert (logprobs_token_ids[0] not in \ + sampling_metadata.prompt_token_ids[batch_idx] and \ + logprobs_token_ids[0] not in \ + sampling_metadata.output_token_ids[batch_idx]) + assert (logprobs_token_ids[VOCAB_SIZE-1] in \ + sampling_metadata.prompt_token_ids[batch_idx] or \ + logprobs_token_ids[VOCAB_SIZE-1] in \ + sampling_metadata.output_token_ids[batch_idx]) From 47c4b749d5e5eae3c8424560d7dc3a85392f7925 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 3 Dec 2024 02:51:43 +0000 Subject: [PATCH 08/39] Fixes Signed-off-by: Sourashis Roy --- tests/v1/engine/test_engine_core.py | 5 + tests/v1/sample/test_sampler.py | 154 ++++++++++++++++------------ 2 files changed, 95 insertions(+), 64 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index e44f92b67ab7d..e855ede4b95a4 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -141,6 +141,11 @@ def test_engine_core(monkeypatch): def test_engine_core_advanced_sampling(monkeypatch): + """ + A basic end-to-end test to verify that the engine functions correctly + when additional sampling parameters, such as min_tokens and + presence_penalty, are set. 
+ """ with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") """Setup the EngineCore.""" diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 090e94a410e0c..850067bcf3cc6 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -154,11 +154,12 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) -def test_sampler_presence_penalty(device: str, batch_size: int): +@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) +def test_sampler_presence_penalty(device: str, batch_size: int, + presence_penalty: float): """ Test to verify that if presence penalty is enabled then tokens - which are already present in the output are penalized vs tokens - which are not yet present in the output. + are penalized as per their presence in the existing output. """ torch.set_default_device(device) # Create fake logits where each token is assigned the same @@ -167,47 +168,47 @@ def test_sampler_presence_penalty(device: str, batch_size: int): sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) output_token_ids = sampling_metadata.output_token_ids - sampling_metadata.presence_penalties = [2.0 for _ in range(batch_size)] + sampling_metadata.presence_penalties = [ + presence_penalty for _ in range(batch_size) + ] sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): # The logprobs in the SamplerOutput are arranged in descending order. # Since all tokens initially have the same logprobs, the non-penalized - # tokens (i.e., token IDs not present in the output) will appear at - # the beginning, while the penalized tokens (i.e., token IDs present in - # the output) will appear at the end of the list. - logprob_for_output_token = sampler_output.logprobs[batch_idx][ + # tokens will appear at the beginning, while the penalized tokens + # will appear at the end of the list. + penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ VOCAB_SIZE - 1] - logprob_for_non_output_token = sampler_output.logprobs[batch_idx][0] - assert logprob_for_non_output_token > logprob_for_output_token - for vocab in range(VOCAB_SIZE): - logprob_index = torch.where( - sampler_output.logprob_token_ids[batch_idx] == - vocab)[0].item() - if vocab in output_token_ids[batch_idx]: - # This token is present in the list of already output tokens. - # Hence it must have been penalized by the presence penalty. - # Verify that the logprob of this token is same as that - # expected for penalized tokens. - assert torch.isclose( - sampler_output.logprobs[batch_idx][logprob_index], - logprob_for_output_token) - else: - # This token is not present in the list of already output - # tokens. Hence it has not been penalized. - # Verify that the logprob of this token is same as that - # expected for non-penalized tokens. - assert torch.isclose( - sampler_output.logprobs[batch_idx][logprob_index], - logprob_for_non_output_token) + penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] + non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0] + non_penalized_log_prod = sampler_output.logprobs[batch_idx][0] + assert non_penalized_log_prod > penalized_log_prod + if presence_penalty > 0: + # If `presence_penalty` is set to a value greater than 0, it + # indicates a preference for new tokens over those already + # present in the output. 
+ # Verify that the penalized token ID exists in the output, while the + # non-penalized token ID does not. + assert penalized_token_id in output_token_ids[batch_idx] + assert non_penalized_token_id not in output_token_ids[batch_idx] + elif presence_penalty < 0: + # If `presence_penalty` is set to a value less than 0, it indicates + # a preference for existing tokens over new ones. Verify that the + # non-penalized token ID exists in the output, while the penalized + # token ID does not. + assert non_penalized_token_id in output_token_ids[batch_idx] + assert penalized_token_id not in output_token_ids[batch_idx] @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) -def test_sampler_frequency_penalty(device: str, batch_size: int): +@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0]) +def test_sampler_frequency_penalty(device: str, batch_size: int, + frequency_penalty: float): """ - Test to verify that if frequency penalty is enabled then tokens with - higher frequency are penalized more than those with lower frequency. + Test to verify that if frequency penalty is enabled then tokens are + penalized as per their frequency of occurrence. """ torch.set_default_device(device) # Create fake logits where each token is assigned the same @@ -215,7 +216,9 @@ def test_sampler_frequency_penalty(device: str, batch_size: int): fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) - sampling_metadata.frequency_penalties = [2.0 for _ in range(batch_size)] + sampling_metadata.frequency_penalties = [ + frequency_penalty for _ in range(batch_size) + ] output_token_ids, sorted_token_ids_in_output = \ _create_weighted_output_token_list(batch_size, VOCAB_SIZE) sampling_metadata.output_token_ids = output_token_ids @@ -223,32 +226,41 @@ def test_sampler_frequency_penalty(device: str, batch_size: int): sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] - token_ids_in_output = sorted_token_ids_in_output[batch_idx] - # The logprobs in the SamplerOutput are arranged in descending order. - # Initially, all tokens have the same logprob, so their order reflects - # their frequency in the existing output. Tokens that occur least - # frequently in the current output appear at the beginning of the list, - # while tokens that occur most frequently are placed at the end. - - # Verify that tokens not present in the current output are ranked - # higher (i.e., appear earlier) than tokens already present in the - # output. - assert not torch.isin(logprobs_token_ids[:-len(token_ids_in_output)], - torch.tensor(token_ids_in_output)).any( - ), "Some values in the tensor are in the list" - # Verify that the logprobs of tokens already present in the output - # decrease as their frequency of occurrence increases. - assert logprobs_token_ids[-len(token_ids_in_output):].tolist() \ - == token_ids_in_output, \ - "The tensor values are not in the same order as the list!" 
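
A minimal standalone sketch of the presence-penalty update the directional checks above rely on (plain PyTorch with toy shapes, not the patched sampler): the penalty is subtracted once for every token already present in the output, so a positive value demotes seen tokens and a negative value promotes them.

    import torch

    # Two toy sequences over a 4-token vocabulary, identical starting logits.
    logits = torch.zeros(2, 4)
    # Token 2 already appeared in sequence 0; token 1 appeared in sequence 1.
    output_mask = torch.tensor([[False, False, True, False],
                                [False, True, False, False]])
    # One presence penalty per sequence: +2.0 favours new tokens,
    # -2.0 favours tokens already present in the output.
    presence_penalties = torch.tensor([[2.0], [-2.0]])

    logits = logits - presence_penalties * output_mask
    # Sequence 0: the seen token drops to -2.0 and sorts last.
    # Sequence 1: the seen token rises to +2.0 and sorts first.
    print(logits)
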
+ non_penalized_token_id = logprobs_token_ids[0] + penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + distinct_sorted_token_ids_in_output = \ + sorted_token_ids_in_output[batch_idx] + most_frequent_token_id = distinct_sorted_token_ids_in_output[ + len(distinct_sorted_token_ids_in_output) - 1] + if frequency_penalty > 0: + # If `frequency_penalty` is set to > 0, it indicates + # a preference for new tokens over existing ones. Verify that the + # non-penalized token ID is not present in the output, while the + # most penalized token is the one that occurs most frequently in + # the output. + assert non_penalized_token_id \ + not in distinct_sorted_token_ids_in_output + assert penalized_token_id == most_frequent_token_id + elif frequency_penalty < 0: + # If `frequency_penalty` is set to < 0, it indicates + # a preference for existing tokens over new ones. Verify that the + # non-penalized token ID is the one that occurs most frequently + # in the output, while the penalized token ID is one that has not + # yet appeared. + assert non_penalized_token_id == most_frequent_token_id + assert penalized_token_id \ + not in distinct_sorted_token_ids_in_output @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) -def test_sampler_repetition_penalty(device: str, batch_size: int): +@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9]) +def test_sampler_repetition_penalty(device: str, batch_size: int, + repetition_penalty: float): """ - Test to verify that if frequency penalty is enabled then tokens with - higher frequency are penalized more than those with lower frequency. + Test to verify that when the repetition penalty is enabled, tokens + are penalized based on their presence in the prompt or the existing + output. """ torch.set_default_device(device) # Create fake logits where each token is assigned the same @@ -256,16 +268,30 @@ def test_sampler_repetition_penalty(device: str, batch_size: int): fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) - sampling_metadata.repetition_penalties = [2.0 for _ in range(batch_size)] + sampling_metadata.repetition_penalties = [ + repetition_penalty for _ in range(batch_size) + ] sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] - assert (logprobs_token_ids[0] not in \ - sampling_metadata.prompt_token_ids[batch_idx] and \ - logprobs_token_ids[0] not in \ - sampling_metadata.output_token_ids[batch_idx]) - assert (logprobs_token_ids[VOCAB_SIZE-1] in \ - sampling_metadata.prompt_token_ids[batch_idx] or \ - logprobs_token_ids[VOCAB_SIZE-1] in \ - sampling_metadata.output_token_ids[batch_idx]) + non_penalized_token_id = logprobs_token_ids[0] + penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx] + output_tokens = sampling_metadata.output_token_ids[batch_idx] + if repetition_penalty > 1.0: + # If `repetition_penalty` > 1.0, verify that the non-penalized + # token ID has not been seen before, while the penalized token ID + # exists either in the prompt or the output. 
+ assert (non_penalized_token_id not in prompt_tokens and \ + non_penalized_token_id not in output_tokens) + assert (penalized_token_id in prompt_tokens or \ + penalized_token_id in output_tokens) + elif repetition_penalty < 1.0: + # If `repetition_penalty` < 1.0, verify that the penalized + # token ID has not been seen before, while the non-penalized + # token ID exists either in the prompt or the output. + assert (penalized_token_id not in prompt_tokens and \ + penalized_token_id not in output_tokens) + assert (non_penalized_token_id in prompt_tokens or \ + non_penalized_token_id in output_tokens) From 35ac8bc9d17240265d94a50916cf82b06b0e1c39 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 3 Dec 2024 13:33:59 +0000 Subject: [PATCH 09/39] Fix tests --- tests/v1/engine/test_engine_core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index e855ede4b95a4..5f85860c74960 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -163,8 +163,9 @@ def test_engine_core_advanced_sampling(monkeypatch): presence_penalty=1.0, frequency_penalty=1.0, repetition_penalty=0.1, - stop_token_ids=[1001, 1002]) - engine_core.add_request(make_request()) + stop_token_ids=[1001, 1002], + ) + engine_core.add_request(request) assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 # Loop through until they are all done. From cce842895da5ba7b3fe6c3fcda74f537d18d74e2 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 3 Dec 2024 14:24:28 +0000 Subject: [PATCH 10/39] Fixes Signed-off-by: Sourashis Roy --- tests/v1/engine/test_engine_core.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 5f85860c74960..2e84ec9a53901 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -159,12 +159,13 @@ def test_engine_core_advanced_sampling(monkeypatch): """Test basic request lifecycle.""" # First request. 
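
The end-to-end test below drives all of the new options through a single SamplingParams. For reference, the equivalent standalone construction looks like this (values simply mirror the ones used in the test):

    from vllm import SamplingParams

    # Suppress the stop tokens for the first 4 generated tokens and apply
    # all three penalty types at once.
    params = SamplingParams(
        min_tokens=4,
        presence_penalty=1.0,
        frequency_penalty=1.0,
        repetition_penalty=0.1,
        stop_token_ids=[1001, 1002],
    )
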
request: EngineCoreRequest = make_request() - request.sampling_params = SamplingParams(min_tokens=4, - presence_penalty=1.0, - frequency_penalty=1.0, - repetition_penalty=0.1, - stop_token_ids=[1001, 1002], - ) + request.sampling_params = SamplingParams( + min_tokens=4, + presence_penalty=1.0, + frequency_penalty=1.0, + repetition_penalty=0.1, + stop_token_ids=[1001, 1002], + ) engine_core.add_request(request) assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 From 9febfbf7fded5875d77e558b849e5f78b38b916b Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 5 Dec 2024 13:50:28 +0000 Subject: [PATCH 11/39] Fixes --- tests/v1/sample/test_sampler.py | 78 +++++++++++---- vllm/model_executor/layers/sampler.py | 36 +++++-- vllm/utils.py | 24 ----- vllm/v1/sample/metadata.py | 12 ++- vllm/v1/sample/sampler.py | 97 ++++++++---------- vllm/v1/worker/gpu_model_runner.py | 135 +++++++++++++++++++++++--- 6 files changed, 252 insertions(+), 130 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 850067bcf3cc6..6b544fe0a3652 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm.utils import make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -20,10 +21,34 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: return fake_logits +def _create_penalty_tensor(batch_size: int, vocab_size: int, + penalty_value: float, + device: torch.device) -> torch.Tensor: + return torch.full((batch_size, vocab_size), + fill_value=penalty_value, + dtype=torch.float, + device=device) + + +def _create_prompt_tokens_tensor( + prompt_token_ids: List[List[int]], + vocab_size: int, + device: torch.device, +) -> torch.Tensor: + return make_tensor_with_pad( + prompt_token_ids, + pad=vocab_size, + device=device, + dtype=torch.int64, + pin_memory=False, + ) + + def _create_default_sampling_metadata( num_output_tokens: int, batch_size: int, vocab_size: int, + device: torch.device, ) -> SamplingMetadata: output_token_ids: List[List[int]] = [] prompt_token_ids: List[List[int]] = [] @@ -45,23 +70,35 @@ def _create_default_sampling_metadata( no_top_k=True, generators={}, max_num_logprobs=VOCAB_SIZE, - prompt_token_ids=prompt_token_ids, + prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, + vocab_size, device), output_token_ids=output_token_ids, - frequency_penalties=[0.0 for _ in range(batch_size)], - presence_penalties=[0.0 for _ in range(batch_size)], - repetition_penalties=[1.0 for _ in range(batch_size)], + frequency_penalties=_create_penalty_tensor(batch_size, vocab_size, 0.0, + device), + presence_penalties=_create_penalty_tensor(batch_size, vocab_size, 0.0, + device), + repetition_penalties=_create_penalty_tensor(batch_size, vocab_size, + 1.0, device), + no_penalties=True, min_tokens=[], stop_token_ids=[], ) return fake_sampling_metadata -def _create_min_token_penalty_dataset( +def _generate_min_token_penalties_and_stop_tokens( num_output_tokens: int, batch_size: int, vocab_size: int, batch_indices_for_min_token_penalty: List[int] ) -> Tuple[List[int], List[Set[int]]]: """ + Generates and returns a list of minimum token penalties (`min_tokens`) + and a corresponding list of stop token IDs (`stop_token_ids`) for each + batch. 
+ If a batch index is included in `batch_indices_for_min_token_penalty`, + a higher `min_tokens` value is assigned (within a randomized range), + and a random set of stop token IDs is created. Otherwise, a lower + `min_tokens` value is assigned, and the stop token IDs set is empty. """ stop_token_ids: List[Set[int]] = [] min_tokens: List[int] = [] @@ -127,10 +164,10 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): torch.set_default_device(device) fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) batch_indices_for_min_token_penalty = np.random.randint( 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() - min_tokens, stop_token_ids = _create_min_token_penalty_dataset( + min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty) sampling_metadata.min_tokens = min_tokens @@ -166,11 +203,11 @@ def test_sampler_presence_penalty(device: str, batch_size: int, # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) output_token_ids = sampling_metadata.output_token_ids - sampling_metadata.presence_penalties = [ - presence_penalty for _ in range(batch_size) - ] + sampling_metadata.presence_penalties = _create_penalty_tensor( + batch_size, VOCAB_SIZE, presence_penalty, torch.device(device)) + sampling_metadata.no_penalties = False sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): @@ -215,13 +252,13 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) - sampling_metadata.frequency_penalties = [ - frequency_penalty for _ in range(batch_size) - ] + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + sampling_metadata.frequency_penalties = _create_penalty_tensor( + batch_size, VOCAB_SIZE, frequency_penalty, torch.device(device)) output_token_ids, sorted_token_ids_in_output = \ _create_weighted_output_token_list(batch_size, VOCAB_SIZE) sampling_metadata.output_token_ids = output_token_ids + sampling_metadata.no_penalties = False sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): @@ -267,17 +304,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, # logit value. 
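
The min_tokens behaviour these helpers exercise amounts to masking a request's stop tokens until it has produced enough output. A small standalone sketch of that masking (plain PyTorch, toy values, not the patched helper):

    import torch

    # One sequence, 6-token vocabulary, stop tokens {4, 5}, min_tokens = 3.
    logits = torch.zeros(1, 6)
    output_token_ids = [[2, 0]]      # only 2 tokens generated so far
    min_tokens = [3]
    stop_token_ids = [{4, 5}]

    # While a sequence has produced fewer than min_tokens tokens, its stop
    # tokens are forced to -inf so they cannot be sampled.
    for seq_idx, min_tok in enumerate(min_tokens):
        if len(output_token_ids[seq_idx]) < min_tok:
            for stop_id in stop_token_ids[seq_idx]:
                logits[seq_idx, stop_id] = -float("inf")
    print(logits)    # columns 4 and 5 are now -inf
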
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE) - sampling_metadata.repetition_penalties = [ - repetition_penalty for _ in range(batch_size) - ] + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + sampling_metadata.repetition_penalties = _create_penalty_tensor( + batch_size, VOCAB_SIZE, repetition_penalty, torch.device(device)) + sampling_metadata.no_penalties = False sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) for batch_idx in range(batch_size): logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] non_penalized_token_id = logprobs_token_ids[0] penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] - prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx] + prompt_tokens = sampling_metadata.prompt_token_ids[ + batch_idx][:].tolist() output_tokens = sampling_metadata.output_token_ids[batch_idx] if repetition_penalty > 1.0: # If `repetition_penalty` > 1.0, verify that the non-penalized diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 5c3ccd74ad8f8..cb32ed5f3444b 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -19,7 +19,7 @@ CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.utils import apply_sampling_penalties +from vllm.utils import get_token_bin_counts_and_mask if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -259,12 +259,11 @@ def forward( # Apply presence and frequency penalties. if do_penalties: - logits = apply_sampling_penalties( - logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. @@ -385,6 +384,29 @@ def _apply_min_tokens_penalty( return logits +def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) + repetition_penalties[~(prompt_mask | output_mask)] = 1.0 + logits = torch.where(logits > 0, logits / repetition_penalties, + logits * repetition_penalties) + + # We follow the definition in OpenAI API. 
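
For reference, a self-contained walk-through of the penalty arithmetic in this helper, with toy numbers (plain PyTorch; counts_and_mask below is a local stand-in for get_token_bin_counts_and_mask and uses the same vocab_size + 1 padding-column trick):

    import torch

    num_seqs, vocab_size = 1, 4
    logits = torch.tensor([[2.0, -1.0, 0.5, 1.0]])
    # The prompt contained token 0; the output so far contains token 1 twice.
    prompt_tokens = torch.tensor([[0]])
    output_tokens = torch.tensor([[1, 1]])

    def counts_and_mask(tokens):
        # vocab_size + 1 columns so a pad value of vocab_size is harmless;
        # the padding column is sliced off before returning.
        counts = torch.zeros(num_seqs, vocab_size + 1, dtype=torch.long)
        counts.scatter_add_(1, tokens, torch.ones_like(tokens))
        counts = counts[:, :vocab_size]
        return counts, counts > 0

    _, prompt_mask = counts_and_mask(prompt_tokens)
    output_counts, output_mask = counts_and_mask(output_tokens)

    # Repetition penalty (2.0 here): divide positive logits and multiply
    # negative ones, but only for tokens seen in the prompt or the output.
    rep = torch.full((num_seqs, 1), 2.0).repeat(1, vocab_size)
    rep[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / rep, logits * rep)

    # Frequency and presence penalties subtract per-occurrence and
    # per-presence terms for tokens that appear in the output.
    logits -= 0.5 * output_counts    # frequency penalty of 0.5
    logits -= 1.0 * output_mask      # presence penalty of 1.0
    print(logits)    # tensor([[ 1.0000, -4.0000,  0.5000,  1.0000]])
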
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + return logits + + def _apply_top_k_top_p( logits: torch.Tensor, p: torch.Tensor, diff --git a/vllm/utils.py b/vllm/utils.py index 4ebe9b89e9562..9c0c53aecbf16 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1656,27 +1656,3 @@ def get_token_bin_counts_and_mask( mask = bin_counts > 0 return bin_counts, mask - - -def apply_sampling_penalties( - logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: - """ - Applies presence, frequency and repetition penalties to the logits. - """ - num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) - output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 - logits[logits > 0] /= repetition_penalties[logits > 0] - logits[logits <= 0] *= repetition_penalties[logits <= 0] - # We follow the definition in OpenAI API. - # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask - return logits diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 0503012088928..d60f7eb5d76f9 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set import torch @@ -20,10 +20,12 @@ class SamplingMetadata: max_num_logprobs: int + no_penalties: bool + prompt_token_ids: Optional[torch.Tensor] + frequency_penalties: torch.Tensor + presence_penalties: torch.Tensor + repetition_penalties: torch.Tensor + output_token_ids: List[List[int]] - prompt_token_ids: List[List[int]] - frequency_penalties: List[float] - presence_penalties: List[float] - repetition_penalties: List[float] min_tokens: List[int] stop_token_ids: List[Set[int]] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index f6a6feaad34b3..54ac448e30174 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn -from vllm.utils import apply_sampling_penalties, make_tensor_with_pad +from vllm.utils import (get_token_bin_counts_and_mask, is_pin_memory_available, + make_tensor_with_pad) from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -21,11 +22,13 @@ def forward( _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, sampling_metadata.stop_token_ids, sampling_metadata.min_tokens) - _apply_penalties(logits, sampling_metadata.prompt_token_ids, - sampling_metadata.output_token_ids, - sampling_metadata.presence_penalties, - sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + _apply_penalties(logits, sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + 
sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids) logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) probs = self.get_probs(logits) @@ -183,67 +186,43 @@ def _apply_min_token_penalties(logits: torch.Tensor, logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") -def _apply_penalties(logits: torch.Tensor, prompt_token_ids: List[List[int]], - output_token_ids: List[List[int]], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float]): +def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: List[List[int]]): """ Applies presence, frequency and repetition penalties to the logits. """ - apply_penalties = any(p != 0.0 for p in presence_penalties) or any( - f != 0.0 - for f in frequency_penalties) or any(r != 1.0 - for r in repetition_penalties) - if apply_penalties: - _, vocab_size = logits.shape - (prompt_tokens_t, output_tokens_t, frequency_penalties_t, - presence_penalties_t, repetition_penalties_t) = \ - _convert_to_tensors( - prompt_token_ids, output_token_ids, frequency_penalties, - presence_penalties, repetition_penalties, vocab_size, - logits.device) - apply_sampling_penalties(logits, prompt_tokens_t, output_tokens_t, - presence_penalties_t, frequency_penalties_t, - repetition_penalties_t) - - -def _convert_to_tensors(prompt_token_ids: List[List[int]], - output_token_ids: List[List[int]], - frequency_penalties: List[float], - presence_penalties: List[float], - repetition_penalties: List[float], vocab_size: int, - device: torch.device) -> Tuple[torch.Tensor, ...]: + num_seqs, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + _, prompt_mask = get_token_bin_counts_and_mask(prompt_token_ids, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_t, vocab_size, num_seqs) + logits[logits > 0] /= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits > 0] + logits[logits <= 0] *= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits <= 0] + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties * output_bin_counts + logits -= presence_penalties * output_mask + + return logits + + +def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, + device: torch.device) -> torch.Tensor: """ Convert the different list data structures to tensors. 
""" - prompt_tokens_tensor = make_tensor_with_pad( - prompt_token_ids, - vocab_size, - device=device, - dtype=torch.int64, - ) output_tokens_tensor = make_tensor_with_pad( output_token_ids, vocab_size, - device=device, + device="cpu", dtype=torch.int64, + pin_memory=is_pin_memory_available(), ) - frequency_penalties_tensor = torch.tensor( - frequency_penalties, - device=device, - dtype=torch.float, - ) - presence_penalties_tensor = torch.tensor( - presence_penalties, - device=device, - dtype=torch.float, - ) - repetition_penalties_tensor = torch.tensor( - repetition_penalties, - device=device, - dtype=torch.float, - ) - return (prompt_tokens_tensor, output_tokens_tensor, - frequency_penalties_tensor, presence_penalties_tensor, - repetition_penalties_tensor) + return output_tokens_tensor.to(device, non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e981514400368..67cdb0183751a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -17,7 +17,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, - is_pin_memory_available) + is_pin_memory_available, make_tensor_with_pad) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput @@ -93,6 +93,7 @@ def __init__( max_num_blocks_per_req=self.max_num_blocks_per_req, device=self.device, pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), ) self.use_cuda_graph = (self.vllm_config.compilation_config.level @@ -466,6 +467,7 @@ def execute_model( logits=logits, sampling_metadata=sampling_metadata, ) + # Update the # NOTE: CPU-GPU synchronization happens here. 
sampled_token_ids = sampler_output.sampled_token_ids.cpu() @@ -629,12 +631,14 @@ def __init__( max_num_blocks_per_req: int, device: torch.device, pin_memory: bool, + vocab_size: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory + self.vocab_size = vocab_size self.req_ids: List[Optional[str]] = [None] * max_num_reqs self.req_id_to_index: Dict[str, int] = {} @@ -687,6 +691,44 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() + self.frequency_penalties = torch.empty((max_num_reqs, vocab_size), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, vocab_size), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: Set[str] = set() + + self.presence_penalties = torch.empty((max_num_reqs, vocab_size), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs, vocab_size), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = \ + self.presence_penalties_cpu_tensor.numpy() + self.presence_penalties_reqs: Set[str] = set() + + self.repetition_penalties = torch.empty((max_num_reqs, vocab_size), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, vocab_size), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu =\ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: Set[str] = set() + + self.prompt_tokens_tensor: Optional[torch.Tensor] = None + # req_index -> generator self.generators: Dict[int, torch.Generator] = {} @@ -732,6 +774,18 @@ def add_request( self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) + self.frequency_penalties_cpu[req_index][:] =\ + sampling_params.frequency_penalty + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[req_index][:] = \ + sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[req_index][:] = \ + sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) self.generators[req_index] = request.generator @@ -751,6 +805,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) @@ -763,6 +820,9 @@ def clear(self) -> None: self.random_reqs.clear() self.top_p_reqs.clear() self.top_k_reqs.clear() + self.frequency_penalties_reqs.clear() + self.presence_penalties_reqs.clear() + self.repetition_penalties_reqs.clear() self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() @@ -803,6 +863,13 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = 
self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index][:] = \ + self.frequency_penalties_cpu[last_req_index][:] + self.presence_penalties_cpu[empty_index][:] = \ + self.presence_penalties_cpu[last_req_index][:] + self.repetition_penalties[empty_index][:] = \ + self.repetition_penalties[last_req_index][:] + generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator @@ -822,25 +889,35 @@ def make_sampling_metadata( self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_k[:self.num_reqs].copy_( self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + if not self.no_reqs_with_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + self.frequency_penalties[:self.num_reqs].copy_( + self.frequency_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.presence_penalties[:self.num_reqs].copy_( + self.presence_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.repetition_penalties[:self.num_reqs].copy_( + self.repetition_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + self.prompt_tokens_tensor = \ + self._construct_prompt_tokens_tensor( + requests, self.vocab_size, device=self.device) output_token_ids: List[List[int]] = [] - prompt_token_ids: List[List[int]] = [] - frequency_penalties: List[float] = [] - presence_penalties: List[float] = [] - repetition_penalties: List[float] = [] min_tokens: List[int] = [] stop_token_ids: List[set[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None request = requests[req_id] + # Currently we create a tensor from the output_token_ids.append(request.output_token_ids) - prompt_token_ids.append(request.prompt_token_ids) - frequency_penalties.append( - request.sampling_params.frequency_penalty) - presence_penalties.append(request.sampling_params.presence_penalty) - repetition_penalties.append( - request.sampling_params.repetition_penalty) min_tokens.append(request.sampling_params.min_tokens) stop_token_ids.append(request.sampling_params.all_stop_token_ids) @@ -854,14 +931,36 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=prompt_token_ids, + prompt_token_ids=self.prompt_tokens_tensor[:self.num_reqs] \ + if self.prompt_tokens_tensor is not None else None, + frequency_penalties=self.frequency_penalties[:self.num_reqs], + presence_penalties=self.presence_penalties[:self.num_reqs], + repetition_penalties=self.repetition_penalties[:self.num_reqs], output_token_ids=output_token_ids, - frequency_penalties=frequency_penalties, - presence_penalties=presence_penalties, - repetition_penalties=repetition_penalties, min_tokens=min_tokens, stop_token_ids=stop_token_ids, + no_penalties=self.no_reqs_with_penalties + ) + + def _construct_prompt_tokens_tensor( + self, requests, vocab_size: int, device: torch.device) \ + -> torch.Tensor: + prompt_token_ids: List[List[int]] = [] + for req_id in self.req_ids[:self.num_reqs]: + assert req_id is not None + request = requests[req_id] + prompt_token_ids.append(request.prompt_token_ids) + prompt_tokens_cpu_tensor = make_tensor_with_pad( + prompt_token_ids, + pad=vocab_size, + device="cpu", 
+ dtype=torch.int64, + pin_memory=self.pin_memory, ) + prompt_tokens_tensor = prompt_tokens_cpu_tensor.to(device=device, + non_blocking=True) + + return prompt_tokens_tensor @property def num_reqs(self) -> int: @@ -883,6 +982,12 @@ def no_top_p(self) -> bool: def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 + @property + def no_reqs_with_penalties(self) -> bool: + return len(self.presence_penalties_reqs) == 0 and \ + len(self.frequency_penalties_reqs) == 0 and \ + len(self.repetition_penalties_reqs) == 0 + @property def max_num_logprobs(self) -> int: return max(self.num_logprobs.values()) if self.num_logprobs else 0 From dc02a4f85ce0f2ce1a45c068b714f695d85abb45 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 5 Dec 2024 14:26:52 +0000 Subject: [PATCH 12/39] Fixes --- vllm/v1/sample/sampler.py | 4 +++- vllm/v1/worker/gpu_model_runner.py | 17 ++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 54ac448e30174..f2ffb5f907d9e 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -220,7 +220,9 @@ def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, """ output_tokens_tensor = make_tensor_with_pad( output_token_ids, - vocab_size, + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + pad=vocab_size, device="cpu", dtype=torch.int64, pin_memory=is_pin_memory_available(), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 67cdb0183751a..7d804da94ff5f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -467,7 +467,6 @@ def execute_model( logits=logits, sampling_metadata=sampling_metadata, ) - # Update the # NOTE: CPU-GPU synchronization happens here. sampled_token_ids = sampler_output.sampled_token_ids.cpu() @@ -690,7 +689,7 @@ def __init__( pin_memory=pin_memory) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() - + # Frequency penalty related data structures self.frequency_penalties = torch.empty((max_num_reqs, vocab_size), dtype=torch.float, device=device) @@ -702,7 +701,7 @@ def __init__( self.frequency_penalties_cpu = \ self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: Set[str] = set() - + # Presence penalty related data structures self.presence_penalties = torch.empty((max_num_reqs, vocab_size), dtype=torch.float, device=device) @@ -714,7 +713,7 @@ def __init__( self.presence_penalties_cpu = \ self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: Set[str] = set() - + # Repetition penalty related data structures self.repetition_penalties = torch.empty((max_num_reqs, vocab_size), dtype=torch.float, device=device) @@ -916,7 +915,13 @@ def make_sampling_metadata( for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None request = requests[req_id] - # Currently we create a tensor from the + # Currently we create a tensor for output_token_ids from scratch + # at each step. However, for the penalties computation what we + # need is stats about the token ids present in the output. This + # stats can be maintained incrementally instead of computing it + # from scratch at each step. + # TODO - Replace this with incremental update to output token + # statistics. 
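
One possible shape of the incremental bookkeeping the TODO above refers to, kept deliberately simple (hypothetical helper, not part of this patch): keep a running per-request Counter and update it with only the newly sampled token instead of rebuilding the statistics from the full output list every step.

    from collections import Counter
    from typing import Dict

    # req_id -> running count of each token id generated so far.
    output_token_counts: Dict[str, Counter] = {}

    def on_new_token(req_id: str, token_id: int) -> None:
        # O(1) per step instead of re-scanning the whole output list.
        output_token_counts.setdefault(req_id, Counter())[token_id] += 1

    # Example: request "req-0" samples tokens 5, 5, 9 over three steps.
    for tok in (5, 5, 9):
        on_new_token("req-0", tok)
    print(output_token_counts["req-0"])    # Counter({5: 2, 9: 1})
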
output_token_ids.append(request.output_token_ids) min_tokens.append(request.sampling_params.min_tokens) stop_token_ids.append(request.sampling_params.all_stop_token_ids) @@ -952,6 +957,8 @@ def _construct_prompt_tokens_tensor( prompt_token_ids.append(request.prompt_token_ids) prompt_tokens_cpu_tensor = make_tensor_with_pad( prompt_token_ids, + # use the value of vocab_size as a pad since we don't have a + # token_id of this value. pad=vocab_size, device="cpu", dtype=torch.int64, From 0db8e4ff7c847d0b5fd5ed97af3640d33c1d462c Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 02:18:46 +0000 Subject: [PATCH 13/39] Addressing comments --- tests/v1/sample/test_sampler.py | 22 ++++++------ vllm/model_executor/layers/sampler.py | 24 ++++--------- vllm/utils.py | 17 --------- vllm/v1/sample/sampler.py | 24 ++++--------- vllm/v1/worker/gpu_model_runner.py | 50 ++++++++++++++++----------- 5 files changed, 53 insertions(+), 84 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 6b544fe0a3652..82180090a0728 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -21,10 +21,9 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: return fake_logits -def _create_penalty_tensor(batch_size: int, vocab_size: int, - penalty_value: float, +def _create_penalty_tensor(batch_size: int, penalty_value: float, device: torch.device) -> torch.Tensor: - return torch.full((batch_size, vocab_size), + return torch.full((batch_size, 1), fill_value=penalty_value, dtype=torch.float, device=device) @@ -73,12 +72,9 @@ def _create_default_sampling_metadata( prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, - frequency_penalties=_create_penalty_tensor(batch_size, vocab_size, 0.0, - device), - presence_penalties=_create_penalty_tensor(batch_size, vocab_size, 0.0, - device), - repetition_penalties=_create_penalty_tensor(batch_size, vocab_size, - 1.0, device), + frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), + presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), + repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), no_penalties=True, min_tokens=[], stop_token_ids=[], @@ -206,7 +202,7 @@ def test_sampler_presence_penalty(device: str, batch_size: int, NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) output_token_ids = sampling_metadata.output_token_ids sampling_metadata.presence_penalties = _create_penalty_tensor( - batch_size, VOCAB_SIZE, presence_penalty, torch.device(device)) + batch_size, presence_penalty, torch.device(device)) sampling_metadata.no_penalties = False sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) @@ -215,6 +211,8 @@ def test_sampler_presence_penalty(device: str, batch_size: int, # Since all tokens initially have the same logprobs, the non-penalized # tokens will appear at the beginning, while the penalized tokens # will appear at the end of the list. 
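
The penalty helpers in this test now build (batch_size, 1) tensors. A quick illustration of why a single value per request is enough: the column broadcasts against the (num_reqs, vocab_size) logits, so no vocab-sized penalty row has to be materialized per request (plain PyTorch, toy shapes).

    import torch

    num_reqs, vocab_size = 2, 4
    logits = torch.zeros(num_reqs, vocab_size)
    # One frequency penalty per request, stored as a column vector.
    frequency_penalties = torch.tensor([[0.5], [1.5]])
    output_counts = torch.tensor([[0., 2., 0., 1.],
                                  [3., 0., 0., 0.]])
    # Broadcasting expands the (2, 1) column across the vocab dimension.
    logits -= frequency_penalties * output_counts
    print(logits)
    # tensor([[ 0.0000, -1.0000,  0.0000, -0.5000],
    #         [-4.5000,  0.0000,  0.0000,  0.0000]])
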
+ print(' sampler_output.logprob_token_ids ' + + str(sampler_output.logprob_token_ids)) penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ VOCAB_SIZE - 1] penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] @@ -254,7 +252,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) sampling_metadata.frequency_penalties = _create_penalty_tensor( - batch_size, VOCAB_SIZE, frequency_penalty, torch.device(device)) + batch_size, frequency_penalty, torch.device(device)) output_token_ids, sorted_token_ids_in_output = \ _create_weighted_output_token_list(batch_size, VOCAB_SIZE) sampling_metadata.output_token_ids = output_token_ids @@ -306,7 +304,7 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) sampling_metadata.repetition_penalties = _create_penalty_tensor( - batch_size, VOCAB_SIZE, repetition_penalty, torch.device(device)) + batch_size, repetition_penalty, torch.device(device)) sampling_metadata.no_penalties = False sampler = Sampler() sampler_output = sampler(fake_logits, sampling_metadata) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cb32ed5f3444b..1a5ba89ab2abb 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -11,6 +11,7 @@ import torch.nn as nn import vllm.envs as envs +from vllm.model_executor.layers.utils import apply_penalties from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, SequenceGroupToSample) @@ -19,7 +20,6 @@ CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.utils import get_token_bin_counts_and_mask if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -389,22 +389,12 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, repetition_penalties: torch.Tensor) -> torch.Tensor: - num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) - output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 - logits = torch.where(logits > 0, logits / repetition_penalties, - logits * repetition_penalties) - - # We follow the definition in OpenAI API. 
- # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask - return logits + repetition_penalties.unsqueeze_(dim=1) + frequency_penalties.unsqueeze_(dim=1) + presence_penalties.unsqueeze_(dim=1) + return apply_penalties(logits, prompt_tokens_tensor, output_tokens_tensor, + presence_penalties, frequency_penalties, + repetition_penalties) def _apply_top_k_top_p( diff --git a/vllm/utils.py b/vllm/utils.py index f67b8b00bec8e..1f19d9eacd16d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1652,20 +1652,3 @@ def resolve_obj_by_qualname(qualname: str) -> Any: module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, obj_name) - - -def get_token_bin_counts_and_mask( - tokens: torch.Tensor, - vocab_size: int, - num_seqs: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - # Compute the bin counts for the tokens. - # vocab_size + 1 for padding. - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) - bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) - bin_counts = bin_counts[:, :vocab_size] - mask = bin_counts > 0 - - return bin_counts, mask diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index f2ffb5f907d9e..55465331df0c5 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn -from vllm.utils import (get_token_bin_counts_and_mask, is_pin_memory_available, - make_tensor_with_pad) +from vllm.model_executor.layers.utils import apply_penalties +from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -29,6 +29,7 @@ def forward( sampling_metadata.frequency_penalties, sampling_metadata.repetition_penalties, sampling_metadata.output_token_ids) + print('logits123 ' + str(logits.sort())) logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) probs = self.get_probs(logits) @@ -194,23 +195,12 @@ def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, """ Applies presence, frequency and repetition penalties to the logits. """ - num_seqs, vocab_size = logits.shape + _, vocab_size = logits.shape output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device) - _, prompt_mask = get_token_bin_counts_and_mask(prompt_token_ids, - vocab_size, num_seqs) - output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_t, vocab_size, num_seqs) - logits[logits > 0] /= torch.where(prompt_mask | output_mask, - repetition_penalties, 1.0)[logits > 0] - logits[logits <= 0] *= torch.where(prompt_mask | output_mask, - repetition_penalties, 1.0)[logits <= 0] - # We follow the definition in OpenAI API. 
- # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties * output_bin_counts - logits -= presence_penalties * output_mask - - return logits + return apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5aca11af4302e..8cab133d21cb5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -702,11 +702,11 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, vocab_size), + self.frequency_penalties = torch.empty((max_num_reqs, 1), dtype=torch.float, device=device) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, vocab_size), + (max_num_reqs, 1), dtype=torch.float, device="cpu", pin_memory=pin_memory) @@ -714,23 +714,22 @@ def __init__( self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: Set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, vocab_size), + self.presence_penalties = torch.empty((max_num_reqs, 1), dtype=torch.float, device=device) - self.presence_penalties_cpu_tensor = torch.empty( - (max_num_reqs, vocab_size), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, 1), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) self.presence_penalties_cpu = \ self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: Set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, vocab_size), + self.repetition_penalties = torch.empty((max_num_reqs, 1), dtype=torch.float, device=device) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, vocab_size), + (max_num_reqs, 1), dtype=torch.float, device="cpu", pin_memory=pin_memory) @@ -739,6 +738,10 @@ def __init__( self.repetition_penalties_reqs: Set[str] = set() self.prompt_tokens_tensor: Optional[torch.Tensor] = None + self.min_tokens: List[int] = [0] * max_num_reqs + self.stop_token_ids: List[Set[int]] = [ + set() for _ in range(max_num_reqs) + ] # req_index -> generator self.generators: Dict[int, torch.Generator] = {} @@ -797,6 +800,8 @@ def add_request( sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) + self.min_tokens[req_index] = sampling_params.min_tokens + self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids self.generators[req_index] = request.generator @@ -880,6 +885,9 @@ def condense(self, empty_req_indices: List[int]) -> None: self.presence_penalties_cpu[last_req_index][:] self.repetition_penalties[empty_index][:] = \ self.repetition_penalties[last_req_index][:] + self.min_tokens[empty_index] = self.min_tokens[last_req_index] + self.stop_token_ids[empty_index] = \ + self.stop_token_ids[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: @@ -894,13 +902,14 @@ def make_sampling_metadata( skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: + print('Hello in copy!!!') self.temperature[:self.num_reqs].copy_( self.temperature_cpu_tensor[:self.num_reqs], 
non_blocking=True) self.top_p[:self.num_reqs].copy_( self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_k[:self.num_reqs].copy_( self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - if not self.no_reqs_with_penalties: + if not self.no_penalties: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. @@ -921,8 +930,6 @@ def make_sampling_metadata( requests, self.vocab_size, device=self.device) output_token_ids: List[List[int]] = [] - min_tokens: List[int] = [] - stop_token_ids: List[set[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None @@ -935,8 +942,6 @@ def make_sampling_metadata( # TODO - Replace this with incremental update to output token # statistics. output_token_ids.append(request.output_token_ids) - min_tokens.append(request.sampling_params.min_tokens) - stop_token_ids.append(request.sampling_params.all_stop_token_ids) return SamplingMetadata( temperature=self.temperature[:self.num_reqs], @@ -954,14 +959,17 @@ def make_sampling_metadata( presence_penalties=self.presence_penalties[:self.num_reqs], repetition_penalties=self.repetition_penalties[:self.num_reqs], output_token_ids=output_token_ids, - min_tokens=min_tokens, - stop_token_ids=stop_token_ids, - no_penalties=self.no_reqs_with_penalties + min_tokens=self.min_tokens[:self.num_reqs], + stop_token_ids=self.stop_token_ids[:self.num_reqs], + no_penalties=self.no_penalties ) def _construct_prompt_tokens_tensor( - self, requests, vocab_size: int, device: torch.device) \ - -> torch.Tensor: + self, + requests, + vocab_size: int, + device: torch.device, + ) -> torch.Tensor: prompt_token_ids: List[List[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None @@ -1002,7 +1010,7 @@ def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 @property - def no_reqs_with_penalties(self) -> bool: + def no_penalties(self) -> bool: return len(self.presence_penalties_reqs) == 0 and \ len(self.frequency_penalties_reqs) == 0 and \ len(self.repetition_penalties_reqs) == 0 From f6c416fd3f43bd42853bdd930d0a6ccce311dc59 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 02:25:24 +0000 Subject: [PATCH 14/39] Remove debug prints Signed-off-by: Sourashis Roy --- tests/v1/sample/test_sampler.py | 2 -- vllm/v1/sample/sampler.py | 1 - vllm/v1/worker/gpu_model_runner.py | 1 - 3 files changed, 4 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 82180090a0728..d7c9178b7dca4 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -211,8 +211,6 @@ def test_sampler_presence_penalty(device: str, batch_size: int, # Since all tokens initially have the same logprobs, the non-penalized # tokens will appear at the beginning, while the penalized tokens # will appear at the end of the list. 
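
Elsewhere in this change, the penalty tensors are only copied host-to-device when at least one request uses a non-default penalty. A tiny helper mirroring that all-defaults check (hypothetical; the batch itself tracks this with per-penalty request sets rather than by scanning values):

    def no_penalties(presence, frequency, repetition) -> bool:
        # All defaults means the penalty pass can be skipped entirely.
        return (all(p == 0.0 for p in presence)
                and all(f == 0.0 for f in frequency)
                and all(r == 1.0 for r in repetition))

    assert no_penalties([0.0, 0.0], [0.0, 0.0], [1.0, 1.0])
    assert not no_penalties([0.0, 0.5], [0.0, 0.0], [1.0, 1.0])
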
- print(' sampler_output.logprob_token_ids ' + - str(sampler_output.logprob_token_ids)) penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ VOCAB_SIZE - 1] penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 55465331df0c5..eabef8ff56a7b 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -29,7 +29,6 @@ def forward( sampling_metadata.frequency_penalties, sampling_metadata.repetition_penalties, sampling_metadata.output_token_ids) - print('logits123 ' + str(logits.sort())) logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) probs = self.get_probs(logits) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8cab133d21cb5..89b32df611556 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -902,7 +902,6 @@ def make_sampling_metadata( skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: - print('Hello in copy!!!') self.temperature[:self.num_reqs].copy_( self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_p[:self.num_reqs].copy_( From 034ff3f777af9bcdf77d2259062ccab3ee700e82 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 02:29:47 +0000 Subject: [PATCH 15/39] Adding utils Signed-off-by: Sourashis Roy --- vllm/model_executor/layers/utils.py | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 vllm/model_executor/layers/utils.py diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py new file mode 100644 index 0000000000000..29b67f2ca8931 --- /dev/null +++ b/vllm/model_executor/layers/utils.py @@ -0,0 +1,54 @@ +"""Utility methods for model layers.""" +import torch +from typing import Tuple + +def get_token_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + """ + Applies penalties in place to the logits tensor + logits : The input logits tensor of shape [num_seqs, vocab_size] + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID + in the vocabulary. output_tokens_tensor: The output tokens tensor. 
+ presence_penalties: The presence penalties of shape [num_seqs, 1] + frequency_penalties: The frequency penalties of shape [num_seqs, 1] + repetition_penalties: The repetition penalties of shape [num_seqs, 1] + """ + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + repetition_penalties = repetition_penalties.repeat( + 1, vocab_size) + logits[logits > 0] /= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits > 0] + logits[logits <= 0] *= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits <= 0] + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties * output_bin_counts + logits -= presence_penalties * output_mask + return logits From 3798152b275330395de37f8a3657e7dcfda4b1ba Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 02:31:11 +0000 Subject: [PATCH 16/39] Fixes Signed-off-by: Sourashis Roy --- vllm/model_executor/layers/utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 29b67f2ca8931..0d77221e7d480 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,7 +1,9 @@ """Utility methods for model layers.""" -import torch from typing import Tuple +import torch + + def get_token_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, @@ -20,10 +22,10 @@ def get_token_bin_counts_and_mask( def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: """ Applies penalties in place to the logits tensor logits : The input logits tensor of shape [num_seqs, vocab_size] @@ -41,8 +43,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.repeat( - 1, vocab_size) + repetition_penalties = repetition_penalties.repeat(1, vocab_size) logits[logits > 0] /= torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)[logits > 0] logits[logits <= 0] *= torch.where(prompt_mask | output_mask, From 00ec97818a85e1d66b7594e314594651a3ad4263 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 02:38:50 +0000 Subject: [PATCH 17/39] More fixes --- vllm/v1/sample/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index eabef8ff56a7b..76a7c9a69bb1d 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -179,7 +179,7 @@ def _apply_min_token_penalties(logits: torch.Tensor, """ min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] for index, min_token in enumerate(min_tokens): - if (min_token > 0 and len(output_token_ids[index]) < min_token): + if (len(output_token_ids[index]) < min_token): for stop_token_id in stop_token_ids[index]: 
min_tokens_logits_to_penalize.append((index, stop_token_id)) if min_tokens_logits_to_penalize: From cf8728075952c59e2fafce8c02e44b5786017f79 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 03:08:13 +0000 Subject: [PATCH 18/39] Format Signed-off-by: Sourashis Roy --- vllm/model_executor/layers/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 0d77221e7d480..9d1711c0b7b7f 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -33,7 +33,8 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, are padded to the maximum prompt length within the batch using `vocab_size` as the padding value. The value `vocab_size` is used for padding because it does not correspond to any valid token ID - in the vocabulary. output_tokens_tensor: The output tokens tensor. + in the vocabulary. + output_tokens_tensor: The output tokens tensor. presence_penalties: The presence penalties of shape [num_seqs, 1] frequency_penalties: The frequency penalties of shape [num_seqs, 1] repetition_penalties: The repetition penalties of shape [num_seqs, 1] From bde6c9e21a9852075230ad8ed2f94e023d6f5afc Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 06:48:01 +0000 Subject: [PATCH 19/39] Some more tests Signed-off-by: Sourashis Roy --- tests/v1/sample/__init__.py | 0 tests/v1/worker/__init__.py | 0 tests/v1/worker/test_gpu_model_runner.py | 226 +++++++++++++++++++++++ 3 files changed, 226 insertions(+) create mode 100644 tests/v1/sample/__init__.py create mode 100644 tests/v1/worker/__init__.py create mode 100644 tests/v1/worker/test_gpu_model_runner.py diff --git a/tests/v1/sample/__init__.py b/tests/v1/sample/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/worker/__init__.py b/tests/v1/worker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py new file mode 100644 index 0000000000000..c7587d91ce990 --- /dev/null +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -0,0 +1,226 @@ +from typing import List, Set, Tuple, Dict + +import numpy as np +import pytest +import torch + +from vllm.utils import make_tensor_with_pad, is_pin_memory_available +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler +from vllm.v1.worker.gpu_model_runner import InputBatch, CachedRequestState + +VOCAB_SIZE = 1024 +NUM_OUTPUT_TOKENS = 20 +MAX_PROMPT_SIZE = 100 +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] +MAX_NUM_PROMPT_TOKENS = 64 + +def _remove_requests( + input_batch: InputBatch, batch_size: int, + reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]: + """ + Remove some requests randomly from the batch and returns a Tuple + of 1) set of request removed 2) indices of the requests removed + ordered in descending order + """ + + num_reqs_to_remove = np.random.randint(0, batch_size) + req_indices_to_remove: Set[int] = set() + for _ in range(num_reqs_to_remove): + req_index_to_remove = np.random.randint(0, batch_size) + req_indices_to_remove.add(req_index_to_remove) + + req_indices_to_remove_list = list(req_indices_to_remove) + req_indices_to_remove_list.sort(reverse=True) + req_ids_to_remove: Set[str] = set() + for index in req_indices_to_remove: + 
input_batch.remove_request(reqs[index].req_id) + req_ids_to_remove.add(reqs[index].req_id) + return (req_ids_to_remove, req_indices_to_remove_list) + + +def _construct_expected_sampling_metadata( + reqs: List[CachedRequestState], + req_ids_retained:Set[int], + req_id_index_in_input_batch: Dict[str, int], + device:torch.device) -> SamplingMetadata: + """ + Constructs and returns the expected SamplingMetadata for this + batch. + """ + num_reqs = len(req_ids_retained) + output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)] + prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)] + presence_penalties = [0.0 for _ in range(num_reqs)] + frequency_penalties = [0.0 for _ in range(num_reqs)] + repetition_penalties = [1.0 for _ in range(num_reqs)] + top_k = [0 for _ in range(num_reqs)] + top_p = [0.0 for _ in range(num_reqs)] + temperature = [0.0 for _ in range(num_reqs)] + stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] + min_tokens = [0 for _ in range(num_reqs)] + for req in reqs: + if not req.req_id in req_ids_retained: + continue + index_in_input_batch = req_id_index_in_input_batch[req.req_id] + output_token_ids[index_in_input_batch] = req.output_token_ids + prompt_token_ids[index_in_input_batch] = req.prompt_token_ids + presence_penalties[index_in_input_batch] = req.sampling_params.presence_penalty + frequency_penalties[index_in_input_batch] = req.sampling_params.frequency_penalty + repetition_penalties[index_in_input_batch] = req.sampling_params.repetition_penalty + top_k[index_in_input_batch] = req.sampling_params.top_k + top_p[index_in_input_batch] = req.sampling_params.top_p + temperature[index_in_input_batch] = req.sampling_params.temperature + stop_token_ids[index_in_input_batch] = req.sampling_params.all_stop_token_ids + min_tokens[index_in_input_batch] = req.sampling_params.min_tokens + + + return SamplingMetadata( + temperature=torch.tensor(temperature, dtype=torch.float, device=device), + all_greedy=False, + all_random=True, + top_p=torch.tensor(top_p, dtype=torch.float, device=device), + top_k=torch.tensor(top_k, dtype=torch.int, device=device), + no_top_p=all(x == 1.0 for x in top_p), + no_top_k=all(x == 0 for x in top_k), + generators={}, + max_num_logprobs=0, + prompt_token_ids= make_tensor_with_pad( + prompt_token_ids, + pad=VOCAB_SIZE, + device=torch.device(device), + dtype=torch.int64, + ), + frequency_penalties=torch.tensor( + frequency_penalties, dtype=torch.float, device=device).unsqueeze(dim=1), + presence_penalties=torch.tensor( + presence_penalties, dtype=torch.float, device=device).unsqueeze(dim=1), + repetition_penalties=torch.tensor( + repetition_penalties, dtype=torch.float, device=device).unsqueeze(dim=1), + output_token_ids=output_token_ids, + min_tokens=min_tokens, + stop_token_ids=stop_token_ids, + no_penalties=(all(x ==0 for x in presence_penalties) and \ + all(x ==0 for x in frequency_penalties) and \ + all(x ==1 for x in repetition_penalties)) + ) + +def _create_sampling_params(): + return SamplingParams( + top_k = np.random.randint(1, 10), + top_p = np.random.uniform(0.0, 1.0), + presence_penalty = np.random.uniform(-2.0, 2.0), + repetition_penalty = np.random.uniform(0.0, 2.0), + frequency_penalty = np.random.uniform(-2.0, 2.0), + min_tokens=np.random.randint(1, 10), + stop_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(10))] + ) + +def _construct_cached_request_state(req_id_suffix:int): + prompt_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range( + np.random.randint(0, 
MAX_PROMPT_SIZE))] + output_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range( + np.random.randint(0, NUM_OUTPUT_TOKENS))] + return CachedRequestState( + req_id=f"req_id_{req_id_suffix}", + prompt_token_ids=prompt_token_ids, + prompt=None, + sampling_params=_create_sampling_params(), + mm_inputs=[], + mm_positions=[], + block_ids=[], + generator=None, + num_computed_tokens=len(output_token_ids), + output_token_ids=output_token_ids + ) + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32, 64]) +def test_sampling_metadata_in_input_batch( + device: str, batch_size: int): + """ + Tests the logic for managing sampling metadata in the InputBatch. + + This test involves adding a set of requests to the InputBatch, + followed by removing a subset of them. Afterward, the batch is compacted, + and the `make_sampling_metadata` method is invoked on the batch. The + output of `make_sampling_metadata` is then compared against the expected + results to ensure correctness. + """ + input_batch : InputBatch = InputBatch( + max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024) + reqs : List[CachedRequestState] = [] + req_id_reqs = {} + # Add requests + for req_index in range(batch_size): + req : CachedRequestState = _construct_cached_request_state(req_index) + input_batch.add_request(req, req_index) + reqs.append(req) + req_id_reqs[req.req_id] = req + + + # Remove some requests + req_ids_to_remove, req_indices_to_remove = _remove_requests( + input_batch, batch_size, reqs) + req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove + + # Compact the input batch + input_batch.condense(req_indices_to_remove) + + # Generate the sampling metadata + sampling_metadata = input_batch.make_sampling_metadata( + req_id_reqs, skip_copy=False) + + # Create expected output. + expected_sampling_metadata = _construct_expected_sampling_metadata( + reqs, req_ids_retained, + input_batch.req_id_to_index, device=torch.device(device) + ) + # Assert the actual and expected output. 
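Aside (illustrative sketch, not part of the patches): the expected-metadata helper above pads ragged prompt lists with VOCAB_SIZE, and the bin-count helper from the earlier utils patch counts into vocab_size + 1 bins precisely so that pad value lands in a throwaway column. A minimal standalone version of that interplay, with toy sizes, before the assertions the test goes on to make:

import torch

VOCAB_SIZE_TOY = 8
prompts = [[1, 2, 2], [3]]                       # ragged prompt token ids
max_len = max(len(p) for p in prompts)
padded = torch.full((len(prompts), max_len), VOCAB_SIZE_TOY, dtype=torch.int64)
for i, p in enumerate(prompts):
    padded[i, :len(p)] = torch.tensor(p)

# vocab_size + 1 bins: the extra bin quietly absorbs the pad value.
bin_counts = torch.zeros((len(prompts), VOCAB_SIZE_TOY + 1), dtype=torch.long)
bin_counts.scatter_add_(1, padded, torch.ones_like(padded))
bin_counts = bin_counts[:, :VOCAB_SIZE_TOY]      # drop the pad column
mask = bin_counts > 0

assert bin_counts[0, 2] == 2                     # token 2 occurs twice in prompt 0
assert not mask[1, 0]                            # prompt 1 never mentions token 0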
+ assert torch.allclose( + expected_sampling_metadata.temperature, sampling_metadata.temperature) + assert torch.allclose(expected_sampling_metadata.top_p, sampling_metadata.top_p) + assert torch.allclose( + expected_sampling_metadata.top_k, sampling_metadata.top_k) + assert torch.allclose( + expected_sampling_metadata.frequency_penalties, + sampling_metadata.frequency_penalties) + assert torch.allclose( + expected_sampling_metadata.presence_penalties, + sampling_metadata.presence_penalties) + assert torch.allclose( + expected_sampling_metadata.repetition_penalties, + sampling_metadata.repetition_penalties) + assert torch.allclose( + expected_sampling_metadata.prompt_token_ids, + sampling_metadata.prompt_token_ids) + assert ( + expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) + assert ( + expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) + assert ( + expected_sampling_metadata.stop_token_ids == sampling_metadata.stop_token_ids) + assert ( + expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties) + assert ( + expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) + assert ( + expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) + + + + + + + + + + From a46cd14f536a47ab16e8fa659bd96d4e617cab31 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 9 Dec 2024 06:51:25 +0000 Subject: [PATCH 20/39] Format Signed-off-by: Sourashis Roy --- tests/v1/worker/test_gpu_model_runner.py | 207 +++++++++++------------ 1 file changed, 101 insertions(+), 106 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index c7587d91ce990..9741c06eecc51 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,14 +1,13 @@ -from typing import List, Set, Tuple, Dict +from typing import Dict, List, Set, Tuple import numpy as np import pytest import torch -from vllm.utils import make_tensor_with_pad, is_pin_memory_available from vllm.sampling_params import SamplingParams +from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.sampler import Sampler -from vllm.v1.worker.gpu_model_runner import InputBatch, CachedRequestState +from vllm.v1.worker.gpu_model_runner import CachedRequestState, InputBatch VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -18,9 +17,10 @@ ] MAX_NUM_PROMPT_TOKENS = 64 + def _remove_requests( - input_batch: InputBatch, batch_size: int, - reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]: + input_batch: InputBatch, batch_size: int, + reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]: """ Remove some requests randomly from the batch and returns a Tuple of 1) set of request removed 2) indices of the requests removed @@ -32,7 +32,7 @@ def _remove_requests( for _ in range(num_reqs_to_remove): req_index_to_remove = np.random.randint(0, batch_size) req_indices_to_remove.add(req_index_to_remove) - + req_indices_to_remove_list = list(req_indices_to_remove) req_indices_to_remove_list.sort(reverse=True) req_ids_to_remove: Set[str] = set() @@ -43,10 +43,9 @@ def _remove_requests( def _construct_expected_sampling_metadata( - reqs: List[CachedRequestState], - req_ids_retained:Set[int], - req_id_index_in_input_batch: Dict[str, int], - device:torch.device) -> SamplingMetadata: + reqs: List[CachedRequestState], req_ids_retained: Set[int], + req_id_index_in_input_batch: Dict[str, int], + device: torch.device) 
-> SamplingMetadata: """ Constructs and returns the expected SamplingMetadata for this batch. @@ -63,21 +62,25 @@ def _construct_expected_sampling_metadata( stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] min_tokens = [0 for _ in range(num_reqs)] for req in reqs: - if not req.req_id in req_ids_retained: + if req.req_id not in req_ids_retained: continue - index_in_input_batch = req_id_index_in_input_batch[req.req_id] + index_in_input_batch = req_id_index_in_input_batch[req.req_id] output_token_ids[index_in_input_batch] = req.output_token_ids prompt_token_ids[index_in_input_batch] = req.prompt_token_ids - presence_penalties[index_in_input_batch] = req.sampling_params.presence_penalty - frequency_penalties[index_in_input_batch] = req.sampling_params.frequency_penalty - repetition_penalties[index_in_input_batch] = req.sampling_params.repetition_penalty + presence_penalties[ + index_in_input_batch] = req.sampling_params.presence_penalty + frequency_penalties[ + index_in_input_batch] = req.sampling_params.frequency_penalty + repetition_penalties[ + index_in_input_batch] = req.sampling_params.repetition_penalty top_k[index_in_input_batch] = req.sampling_params.top_k top_p[index_in_input_batch] = req.sampling_params.top_p temperature[index_in_input_batch] = req.sampling_params.temperature - stop_token_ids[index_in_input_batch] = req.sampling_params.all_stop_token_ids + stop_token_ids[ + index_in_input_batch] = req.sampling_params.all_stop_token_ids min_tokens[index_in_input_batch] = req.sampling_params.min_tokens - + return SamplingMetadata( temperature=torch.tensor(temperature, dtype=torch.float, device=device), all_greedy=False, @@ -95,11 +98,14 @@ def _construct_expected_sampling_metadata( dtype=torch.int64, ), frequency_penalties=torch.tensor( - frequency_penalties, dtype=torch.float, device=device).unsqueeze(dim=1), + frequency_penalties, dtype=torch.float, + device=device).unsqueeze(dim=1), presence_penalties=torch.tensor( - presence_penalties, dtype=torch.float, device=device).unsqueeze(dim=1), + presence_penalties, dtype=torch.float, + device=device).unsqueeze(dim=1), repetition_penalties=torch.tensor( - repetition_penalties, dtype=torch.float, device=device).unsqueeze(dim=1), + repetition_penalties, dtype=torch.float, + device=device).unsqueeze(dim=1), output_token_ids=output_token_ids, min_tokens=min_tokens, stop_token_ids=stop_token_ids, @@ -108,39 +114,44 @@ def _construct_expected_sampling_metadata( all(x ==1 for x in repetition_penalties)) ) + def _create_sampling_params(): - return SamplingParams( - top_k = np.random.randint(1, 10), - top_p = np.random.uniform(0.0, 1.0), - presence_penalty = np.random.uniform(-2.0, 2.0), - repetition_penalty = np.random.uniform(0.0, 2.0), - frequency_penalty = np.random.uniform(-2.0, 2.0), - min_tokens=np.random.randint(1, 10), - stop_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(10))] - ) + return SamplingParams(top_k=np.random.randint(1, 10), + top_p=np.random.uniform(0.0, 1.0), + presence_penalty=np.random.uniform(-2.0, 2.0), + repetition_penalty=np.random.uniform(0.0, 2.0), + frequency_penalty=np.random.uniform(-2.0, 2.0), + min_tokens=np.random.randint(1, 10), + stop_token_ids=[ + np.random.randint(0, VOCAB_SIZE) + for _ in range(np.random.randint(10)) + ]) + + +def _construct_cached_request_state(req_id_suffix: int): + prompt_token_ids = [ + np.random.randint(0, VOCAB_SIZE) + for _ in range(np.random.randint(0, MAX_PROMPT_SIZE)) + ] + output_token_ids = [ + np.random.randint(0, VOCAB_SIZE) + 
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) + ] + return CachedRequestState(req_id=f"req_id_{req_id_suffix}", + prompt_token_ids=prompt_token_ids, + prompt=None, + sampling_params=_create_sampling_params(), + mm_inputs=[], + mm_positions=[], + block_ids=[], + generator=None, + num_computed_tokens=len(output_token_ids), + output_token_ids=output_token_ids) -def _construct_cached_request_state(req_id_suffix:int): - prompt_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range( - np.random.randint(0, MAX_PROMPT_SIZE))] - output_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range( - np.random.randint(0, NUM_OUTPUT_TOKENS))] - return CachedRequestState( - req_id=f"req_id_{req_id_suffix}", - prompt_token_ids=prompt_token_ids, - prompt=None, - sampling_params=_create_sampling_params(), - mm_inputs=[], - mm_positions=[], - block_ids=[], - generator=None, - num_computed_tokens=len(output_token_ids), - output_token_ids=output_token_ids - ) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32, 64]) -def test_sampling_metadata_in_input_batch( - device: str, batch_size: int): +def test_sampling_metadata_in_input_batch(device: str, batch_size: int): """ Tests the logic for managing sampling metadata in the InputBatch. @@ -150,77 +161,61 @@ def test_sampling_metadata_in_input_batch( output of `make_sampling_metadata` is then compared against the expected results to ensure correctness. """ - input_batch : InputBatch = InputBatch( - max_num_reqs=batch_size, - max_model_len=1024, - max_num_blocks_per_req=10, - device=torch.device(device), - pin_memory=is_pin_memory_available(), - vocab_size=1024) - reqs : List[CachedRequestState] = [] + input_batch: InputBatch = InputBatch(max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024) + reqs: List[CachedRequestState] = [] req_id_reqs = {} # Add requests for req_index in range(batch_size): - req : CachedRequestState = _construct_cached_request_state(req_index) + req: CachedRequestState = _construct_cached_request_state(req_index) input_batch.add_request(req, req_index) reqs.append(req) req_id_reqs[req.req_id] = req - - + # Remove some requests req_ids_to_remove, req_indices_to_remove = _remove_requests( input_batch, batch_size, reqs) req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove - + # Compact the input batch input_batch.condense(req_indices_to_remove) - + # Generate the sampling metadata - sampling_metadata = input_batch.make_sampling_metadata( - req_id_reqs, skip_copy=False) - + sampling_metadata = input_batch.make_sampling_metadata(req_id_reqs, + skip_copy=False) + # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( - reqs, req_ids_retained, - input_batch.req_id_to_index, device=torch.device(device) - ) + reqs, + req_ids_retained, + input_batch.req_id_to_index, + device=torch.device(device)) # Assert the actual and expected output. 
- assert torch.allclose( - expected_sampling_metadata.temperature, sampling_metadata.temperature) - assert torch.allclose(expected_sampling_metadata.top_p, sampling_metadata.top_p) - assert torch.allclose( - expected_sampling_metadata.top_k, sampling_metadata.top_k) - assert torch.allclose( - expected_sampling_metadata.frequency_penalties, - sampling_metadata.frequency_penalties) - assert torch.allclose( - expected_sampling_metadata.presence_penalties, - sampling_metadata.presence_penalties) - assert torch.allclose( - expected_sampling_metadata.repetition_penalties, - sampling_metadata.repetition_penalties) - assert torch.allclose( - expected_sampling_metadata.prompt_token_ids, - sampling_metadata.prompt_token_ids) + assert torch.allclose(expected_sampling_metadata.temperature, + sampling_metadata.temperature) + assert torch.allclose(expected_sampling_metadata.top_p, + sampling_metadata.top_p) + assert torch.allclose(expected_sampling_metadata.top_k, + sampling_metadata.top_k) + assert torch.allclose(expected_sampling_metadata.frequency_penalties, + sampling_metadata.frequency_penalties) + assert torch.allclose(expected_sampling_metadata.presence_penalties, + sampling_metadata.presence_penalties) + assert torch.allclose(expected_sampling_metadata.repetition_penalties, + sampling_metadata.repetition_penalties) + assert torch.allclose(expected_sampling_metadata.prompt_token_ids, + sampling_metadata.prompt_token_ids) + assert (expected_sampling_metadata.output_token_ids == + sampling_metadata.output_token_ids) assert ( - expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) - assert ( - expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) - assert ( - expected_sampling_metadata.stop_token_ids == sampling_metadata.stop_token_ids) - assert ( - expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties) - assert ( - expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) - assert ( - expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) - - - - - - - - - - + expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) + assert (expected_sampling_metadata.stop_token_ids == + sampling_metadata.stop_token_ids) + assert (expected_sampling_metadata.no_penalties == + sampling_metadata.no_penalties) + assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) + assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) From 795b1f857d42a8a2f44587ff6df931d935eddb63 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 10 Dec 2024 06:56:10 +0000 Subject: [PATCH 21/39] Merge --- vllm/v1/worker/gpu_input_batch.py | 150 +++++++++++ vllm/v1/worker/gpu_model_runner.py | 416 +---------------------------- 2 files changed, 151 insertions(+), 415 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 25d95ac6e26af..f88044cd949e4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -8,6 +8,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata if TYPE_CHECKING: @@ -43,12 +44,14 @@ def __init__( max_num_blocks_per_req: int, device: torch.device, pin_memory: bool, + vocab_size: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory 
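Aside (illustrative sketch, not part of the patches): every per-request field in this class follows one pattern, which the penalty buffers added in the next hunks repeat: a device tensor paired with a pinned CPU staging tensor whose NumPy view is written cheaply per request and flushed to the GPU in a single non-blocking copy. In miniature, with made-up names and a CPU fallback so it runs anywhere:

import torch

max_num_reqs = 4
device = "cuda" if torch.cuda.is_available() else "cpu"
pin_memory = torch.cuda.is_available()

gpu_buf = torch.empty((max_num_reqs, ), dtype=torch.float, device=device)
cpu_buf = torch.empty((max_num_reqs, ), dtype=torch.float,
                      device="cpu", pin_memory=pin_memory)
cpu_np = cpu_buf.numpy()          # zero-copy view for cheap per-request writes

cpu_np[2] = 0.7                   # e.g. store a penalty for request index 2
num_reqs = 3
# One async host-to-device copy of the live prefix instead of many tiny copies.
gpu_buf[:num_reqs].copy_(cpu_buf[:num_reqs], non_blocking=True)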
+ self.vocab_size = vocab_size self.req_ids: List[Optional[str]] = [None] * max_num_reqs self.req_id_to_index: Dict[str, int] = {} @@ -101,6 +104,50 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() + # Frequency penalty related data structures + self.frequency_penalties = torch.empty((max_num_reqs, 1), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, 1), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: Set[str] = set() + + # Presence penalty related data structures + self.presence_penalties = torch.empty((max_num_reqs, 1), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, 1), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = \ + self.presence_penalties_cpu_tensor.numpy() + self.presence_penalties_reqs: Set[str] = set() + + # Repetition penalty related data structures + self.repetition_penalties = torch.empty((max_num_reqs, 1), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, 1), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu =\ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: Set[str] = set() + + self.prompt_tokens_tensor: Optional[torch.Tensor] = None + self.min_tokens: List[int] = [0] * max_num_reqs + self.stop_token_ids: List[Set[int]] = [ + set() for _ in range(max_num_reqs) + ] + # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. @@ -148,6 +195,20 @@ def add_request( self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) + self.frequency_penalties_cpu[req_index][:] =\ + sampling_params.frequency_penalty + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[req_index][:] = \ + sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[req_index][:] = \ + sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + self.min_tokens[req_index] = sampling_params.min_tokens + self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. 
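Aside (illustrative sketch, not part of the patches): add_request above also records min_tokens and all_stop_token_ids per slot; those feed the min-tokens rule in the sampler, which pushes stop-token logits to -inf until a sequence has produced at least min_tokens outputs. A toy standalone version of that rule:

import torch

logits = torch.zeros(2, 10)                 # [num_seqs, vocab_size]
output_token_ids = [[5, 1], [7]]            # tokens generated so far
min_tokens = [1, 3]
stop_token_ids = [{9}, {8, 9}]

to_penalize = []
for i, min_tok in enumerate(min_tokens):
    if len(output_token_ids[i]) < min_tok:
        to_penalize.extend((i, t) for t in stop_token_ids[i])
if to_penalize:
    rows, cols = zip(*to_penalize)
    logits[rows, cols] = -float("inf")      # sequence 1 cannot stop yet

assert logits[1, 8] == -float("inf")
assert logits[0, 9] == 0.0                  # sequence 0 already met its minimum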
@@ -170,6 +231,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) @@ -182,6 +246,9 @@ def clear(self) -> None: self.random_reqs.clear() self.top_p_reqs.clear() self.top_k_reqs.clear() + self.frequency_penalties_reqs.clear() + self.presence_penalties_reqs.clear() + self.repetition_penalties_reqs.clear() self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() @@ -222,6 +289,15 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index][:] = \ + self.frequency_penalties_cpu[last_req_index][:] + self.presence_penalties_cpu[empty_index][:] = \ + self.presence_penalties_cpu[last_req_index][:] + self.repetition_penalties_cpu[empty_index][:] = \ + self.repetition_penalties_cpu[last_req_index][:] + self.min_tokens[empty_index] = self.min_tokens[last_req_index] + self.stop_token_ids[empty_index] = \ + self.stop_token_ids[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator @@ -231,6 +307,7 @@ def condense(self, empty_req_indices: List[int]) -> None: def make_sampling_metadata( self, + requests: Dict[str, CachedRequestState], skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: @@ -240,6 +317,40 @@ def make_sampling_metadata( self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_k[:self.num_reqs].copy_( self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + self.frequency_penalties[:self.num_reqs].copy_( + self.frequency_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.presence_penalties[:self.num_reqs].copy_( + self.presence_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.repetition_penalties[:self.num_reqs].copy_( + self.repetition_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + self.prompt_tokens_tensor = \ + self._construct_prompt_tokens_tensor( + requests, self.vocab_size, device=self.device) + + output_token_ids: List[List[int]] = [] + + for req_id in self.req_ids[:self.num_reqs]: + assert req_id is not None + request = requests[req_id] + # Currently we create a tensor for output_token_ids from scratch + # at each step. However, for the penalties computation what we + # need is stats about the token ids present in the output. This + # stats can be maintained incrementally instead of computing it + # from scratch at each step. + # TODO - Replace this with incremental update to output token + # statistics. 
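Aside (illustrative sketch, not part of the patches): make_sampling_metadata above syncs the CPU tensors to the GPU only when skip_copy is False, and syncs the three penalty tensors plus the padded prompt tensor only when some request actually uses penalties. The same copy-avoidance idea in miniature, with toy names:

import torch

def sync(gpu, cpu, num_reqs, skip_copy, no_penalties):
    if skip_copy:
        return                               # batch unchanged since last step
    gpu["temperature"][:num_reqs].copy_(cpu["temperature"][:num_reqs],
                                        non_blocking=True)
    if not no_penalties:                     # penalties are the expensive part
        for key in ("frequency", "presence", "repetition"):
            gpu[key][:num_reqs].copy_(cpu[key][:num_reqs], non_blocking=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
names = ("temperature", "frequency", "presence", "repetition")
gpu = {n: torch.zeros(8, device=device) for n in names}
cpu = {n: torch.rand(8) for n in names}
sync(gpu, cpu, num_reqs=4, skip_copy=False, no_penalties=True)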
+ output_token_ids.append(request.output_token_ids) + return SamplingMetadata( temperature=self.temperature[:self.num_reqs], all_greedy=self.all_greedy, @@ -250,7 +361,40 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=self.prompt_tokens_tensor[:self.num_reqs] \ + if self.prompt_tokens_tensor is not None else None, + frequency_penalties=self.frequency_penalties[:self.num_reqs], + presence_penalties=self.presence_penalties[:self.num_reqs], + repetition_penalties=self.repetition_penalties[:self.num_reqs], + output_token_ids=output_token_ids, + min_tokens=self.min_tokens[:self.num_reqs], + stop_token_ids=self.stop_token_ids[:self.num_reqs], + no_penalties=self.no_penalties + ) + + def _construct_prompt_tokens_tensor( + self, + requests, + vocab_size: int, + device: torch.device, + ) -> torch.Tensor: + prompt_token_ids: List[List[int]] = [] + for req_id in self.req_ids[:self.num_reqs]: + assert req_id is not None + request = requests[req_id] + prompt_token_ids.append(request.prompt_token_ids) + prompt_tokens_cpu_tensor = make_tensor_with_pad( + prompt_token_ids, + # use the value of vocab_size as a pad since we don't have a + # token_id of this value. + pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory, ) + prompt_tokens_tensor = prompt_tokens_cpu_tensor.to(device=device, + non_blocking=True) + return prompt_tokens_tensor @property def num_reqs(self) -> int: @@ -272,6 +416,12 @@ def no_top_p(self) -> bool: def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 + @property + def no_penalties(self) -> bool: + return len(self.presence_penalties_reqs) == 0 and \ + len(self.frequency_penalties_reqs) == 0 and \ + len(self.repetition_penalties_reqs) == 0 + @property def max_num_logprobs(self) -> int: return max(self.num_logprobs.values()) if self.num_logprobs else 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5e4b76b1741fc..3943d67b199e1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,7 +16,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput @@ -613,417 +613,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - self.vocab_size = vocab_size - - self.req_ids: 
List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, 1), - dtype=torch.float, - device=device) - self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, 1), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() - self.frequency_penalties_reqs: Set[str] = set() - # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, 1), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, 1), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = \ - self.presence_penalties_cpu_tensor.numpy() - self.presence_penalties_reqs: Set[str] = set() - # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, 1), - dtype=torch.float, - device=device) - self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, 1), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu =\ - self.repetition_penalties_cpu_tensor.numpy() - self.repetition_penalties_reqs: Set[str] = set() - - self.prompt_tokens_tensor: Optional[torch.Tensor] = None - self.min_tokens: List[int] = [0] * max_num_reqs - self.stop_token_ids: List[Set[int]] = [ - set() for _ in range(max_num_reqs) - ] - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - self.frequency_penalties_cpu[req_index][:] =\ - sampling_params.frequency_penalty - if sampling_params.frequency_penalty != 0.0: - self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[req_index][:] = \ - sampling_params.presence_penalty - if sampling_params.presence_penalty != 0.0: - self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[req_index][:] = \ - sampling_params.repetition_penalty - if sampling_params.repetition_penalty != 1.0: - self.repetition_penalties_reqs.add(req_id) - self.min_tokens[req_index] = sampling_params.min_tokens - self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.frequency_penalties_reqs.discard(req_id) - self.presence_penalties_reqs.discard(req_id) - self.repetition_penalties_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.frequency_penalties_reqs.clear() - self.presence_penalties_reqs.clear() - self.repetition_penalties_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. 
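Aside (illustrative sketch, not part of the patches): the condense() logic being removed from this file (it moved to gpu_input_batch.py earlier in this patch) compacts the batch by filling each empty slot, smallest first, with the last occupied row; empty_req_indices must be sorted in descending order so that pop() yields the smallest index. A toy trace of the scheme:

import numpy as np

state = np.array([10, 11, 12, 13, 14])      # one value per request slot
num_reqs = 3                                 # live requests after removals
empty = [3, 1]                               # removed slots, descending order

last = num_reqs + len(empty) - 1             # index of the last occupied row
while empty:
    while last in empty:                     # skip rows that are themselves empty
        last -= 1
    hole = empty.pop()                       # smallest empty index
    if hole >= last:
        break
    state[hole] = state[last]                # move the last live row down
    last -= 1

print(state[:num_reqs])                      # compacted prefix: [10, 14, 12]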
- req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[empty_index][:] = \ - self.frequency_penalties_cpu[last_req_index][:] - self.presence_penalties_cpu[empty_index][:] = \ - self.presence_penalties_cpu[last_req_index][:] - self.repetition_penalties[empty_index][:] = \ - self.repetition_penalties[last_req_index][:] - self.min_tokens[empty_index] = self.min_tokens[last_req_index] - self.stop_token_ids[empty_index] = \ - self.stop_token_ids[last_req_index] - - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - def make_sampling_metadata( - self, - requests: Dict[str, CachedRequestState], - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - if not self.no_penalties: - # Since syncing these tensors is expensive only copy them - # if necessary i.e. if there are requests which require - # penalties to be applied during sampling. - self.frequency_penalties[:self.num_reqs].copy_( - self.frequency_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True) - self.presence_penalties[:self.num_reqs].copy_( - self.presence_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True) - self.repetition_penalties[:self.num_reqs].copy_( - self.repetition_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True) - # The prompt tokens are used only for applying penalties during - # the sampling process. Hence copy these tensors only when - # there are requests which need penalties to be applied. - self.prompt_tokens_tensor = \ - self._construct_prompt_tokens_tensor( - requests, self.vocab_size, device=self.device) - - output_token_ids: List[List[int]] = [] - - for req_id in self.req_ids[:self.num_reqs]: - assert req_id is not None - request = requests[req_id] - # Currently we create a tensor for output_token_ids from scratch - # at each step. However, for the penalties computation what we - # need is stats about the token ids present in the output. This - # stats can be maintained incrementally instead of computing it - # from scratch at each step. - # TODO - Replace this with incremental update to output token - # statistics. 
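Aside (hedged sketch, not part of the patches): the TODO above, repeated in the new gpu_input_batch.py, suggests maintaining output-token statistics incrementally rather than rebuilding them from the output lists each step. One way that could look; the names and hooks here are assumptions, not anything the series implements:

import torch

max_num_reqs, vocab_size = 4, 16
output_bin_counts = torch.zeros(max_num_reqs, vocab_size + 1, dtype=torch.long)

def on_token_sampled(req_index: int, token_id: int) -> None:
    output_bin_counts[req_index, token_id] += 1   # one cell per new token

def on_request_removed(req_index: int) -> None:
    output_bin_counts[req_index].zero_()          # reuse the slot for a new request

on_token_sampled(0, 5)
on_token_sampled(0, 5)
counts = output_bin_counts[:1, :vocab_size]       # per-request token counts
mask = counts > 0                                 # "token has appeared" mask
print(counts[0, 5].item(), bool(mask[0, 5]))      # 2 True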
- output_token_ids.append(request.output_token_ids) - - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=self.prompt_tokens_tensor[:self.num_reqs] \ - if self.prompt_tokens_tensor is not None else None, - frequency_penalties=self.frequency_penalties[:self.num_reqs], - presence_penalties=self.presence_penalties[:self.num_reqs], - repetition_penalties=self.repetition_penalties[:self.num_reqs], - output_token_ids=output_token_ids, - min_tokens=self.min_tokens[:self.num_reqs], - stop_token_ids=self.stop_token_ids[:self.num_reqs], - no_penalties=self.no_penalties - ) - - def _construct_prompt_tokens_tensor( - self, - requests, - vocab_size: int, - device: torch.device, - ) -> torch.Tensor: - prompt_token_ids: List[List[int]] = [] - for req_id in self.req_ids[:self.num_reqs]: - assert req_id is not None - request = requests[req_id] - prompt_token_ids.append(request.prompt_token_ids) - prompt_tokens_cpu_tensor = make_tensor_with_pad( - prompt_token_ids, - # use the value of vocab_size as a pad since we don't have a - # token_id of this value. - pad=vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=self.pin_memory, - ) - prompt_tokens_tensor = prompt_tokens_cpu_tensor.to(device=device, - non_blocking=True) - - return prompt_tokens_tensor - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def no_penalties(self) -> bool: - return len(self.presence_penalties_reqs) == 0 and \ - len(self.frequency_penalties_reqs) == 0 and \ - len(self.repetition_penalties_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 From 9a1ab49dee5b0f2dddf0704e408faf6d580f8d51 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 10 Dec 2024 06:58:41 +0000 Subject: [PATCH 22/39] Rename test file --- .../worker/{test_gpu_model_runner.py => test_gpu_input_batch.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/v1/worker/{test_gpu_model_runner.py => test_gpu_input_batch.py} (100%) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_input_batch.py similarity index 100% rename from tests/v1/worker/test_gpu_model_runner.py rename to tests/v1/worker/test_gpu_input_batch.py From 6472b705fd0f9ef936b0bc4760fb4ff460907c4f Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 10 Dec 2024 07:05:11 +0000 Subject: [PATCH 23/39] Tests Signed-off-by: Sourashis Roy --- tests/v1/worker/test_gpu_input_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 9741c06eecc51..cc2b98c6faac9 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ 
b/tests/v1/worker/test_gpu_input_batch.py @@ -7,7 +7,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_model_runner import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 From 239a3fd78fd84fa03044c6896e129d20d1a7dd84 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 10 Dec 2024 08:04:58 +0000 Subject: [PATCH 24/39] Changes Signed-off-by: Sourashis Roy --- tests/v1/sample/test_sampler.py | 2 +- tests/v1/worker/test_gpu_input_batch.py | 7 +++--- vllm/model_executor/layers/sampler.py | 23 ++++-------------- vllm/model_executor/layers/utils.py | 7 +++--- vllm/v1/worker/gpu_input_batch.py | 32 ++++++++++++------------- 5 files changed, 30 insertions(+), 41 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index d7c9178b7dca4..d8d055805cbea 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -23,7 +23,7 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: def _create_penalty_tensor(batch_size: int, penalty_value: float, device: torch.device) -> torch.Tensor: - return torch.full((batch_size, 1), + return torch.full((batch_size, ), fill_value=penalty_value, dtype=torch.float, device=device) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index cc2b98c6faac9..4c8141bab8e71 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -99,13 +99,13 @@ def _construct_expected_sampling_metadata( ), frequency_penalties=torch.tensor( frequency_penalties, dtype=torch.float, - device=device).unsqueeze(dim=1), + device=device), presence_penalties=torch.tensor( presence_penalties, dtype=torch.float, - device=device).unsqueeze(dim=1), + device=device), repetition_penalties=torch.tensor( repetition_penalties, dtype=torch.float, - device=device).unsqueeze(dim=1), + device=device), output_token_ids=output_token_ids, min_tokens=min_tokens, stop_token_ids=stop_token_ids, @@ -194,6 +194,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): req_ids_retained, input_batch.req_id_to_index, device=torch.device(device)) + # Assert the actual and expected output. assert torch.allclose(expected_sampling_metadata.temperature, sampling_metadata.temperature) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1a5ba89ab2abb..c2d12c466ba45 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -259,11 +259,11 @@ def forward( # Apply presence and frequency penalties. if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + logits = apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. 
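Aside (illustrative sketch, not part of the patches): the shared apply_penalties helper that the sampler above now calls follows the OpenAI parameter definitions: the repetition penalty rescales the logits of any token already seen in the prompt or output, the frequency penalty subtracts penalty * count, and the presence penalty subtracts a flat amount for any token that has appeared in the output. Worked on a toy batch with plain PyTorch:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])        # [num_seqs=1, vocab=4]
output_counts = torch.tensor([[0, 2, 1, 0]])           # counts in generated text
prompt_mask = torch.tensor([[True, False, False, False]])
output_mask = output_counts > 0
freq, pres, rep = 0.5, 0.4, 1.2

seen = prompt_mask | output_mask
scale = torch.where(seen, torch.full_like(logits, rep), torch.ones_like(logits))
logits = torch.where(logits > 0, logits / scale, logits * scale)
logits = logits - freq * output_counts                 # frequency: per occurrence
logits = logits - pres * output_mask.float()           # presence: once per token

# token 0: 2.0 / 1.2; token 1: -1.0 * 1.2 - 0.5 * 2 - 0.4; token 3: untouched
print(logits)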
@@ -384,19 +384,6 @@ def _apply_min_tokens_penalty( return logits -def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: - repetition_penalties.unsqueeze_(dim=1) - frequency_penalties.unsqueeze_(dim=1) - presence_penalties.unsqueeze_(dim=1) - return apply_penalties(logits, prompt_tokens_tensor, output_tokens_tensor, - presence_penalties, frequency_penalties, - repetition_penalties) - - def _apply_top_k_top_p( logits: torch.Tensor, p: torch.Tensor, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 9d1711c0b7b7f..4fc4880290e0b 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -44,13 +44,14 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.repeat(1, vocab_size) + repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat( + 1, vocab_size) logits[logits > 0] /= torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)[logits > 0] logits[logits <= 0] *= torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)[logits <= 0] # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties * output_bin_counts - logits -= presence_penalties * output_mask + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask return logits diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index f88044cd949e4..2098468db86eb 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -105,11 +105,11 @@ def __init__( self.top_k_reqs: Set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, 1), + self.frequency_penalties = torch.empty((max_num_reqs, ), dtype=torch.float, device=device) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, 1), + (max_num_reqs, ), dtype=torch.float, device="cpu", pin_memory=pin_memory) @@ -118,10 +118,10 @@ def __init__( self.frequency_penalties_reqs: Set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, 1), + self.presence_penalties = torch.empty((max_num_reqs, ), dtype=torch.float, device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, 1), + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), dtype=torch.float, device="cpu", pin_memory=pin_memory) @@ -130,11 +130,11 @@ def __init__( self.presence_penalties_reqs: Set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, 1), + self.repetition_penalties = torch.empty((max_num_reqs, ), dtype=torch.float, device=device) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, 1), + (max_num_reqs, ), dtype=torch.float, device="cpu", pin_memory=pin_memory) @@ -195,15 +195,15 @@ def add_request( self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) - self.frequency_penalties_cpu[req_index][:] =\ + 
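Aside (illustrative sketch, not part of the patches): this patch flattens the penalty buffers from [n, 1] columns to plain (n, ) vectors and moves the unsqueeze into apply_penalties itself. The broadcasting that relies on, in a few lines (a plain unsqueeze is used here; the helper in the series uses the in-place unsqueeze_, which behaves the same for this purpose):

import torch

num_seqs, vocab_size = 3, 5
logits = torch.randn(num_seqs, vocab_size)
frequency_penalties = torch.tensor([0.0, 0.5, 1.0])    # shape (num_seqs, )
output_bin_counts = torch.randint(0, 3, (num_seqs, vocab_size))

# unsqueeze(dim=1) gives (num_seqs, 1), which broadcasts across vocab_size.
penalized = logits - frequency_penalties.unsqueeze(1) * output_bin_counts
assert penalized.shape == (num_seqs, vocab_size)
assert torch.allclose(penalized[0], logits[0])         # zero-penalty row unchanged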
self.frequency_penalties_cpu[req_index] =\ sampling_params.frequency_penalty if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[req_index][:] = \ + self.presence_penalties_cpu[req_index] = \ sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[req_index][:] = \ + self.repetition_penalties_cpu[req_index] = \ sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) @@ -289,12 +289,12 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[empty_index][:] = \ - self.frequency_penalties_cpu[last_req_index][:] - self.presence_penalties_cpu[empty_index][:] = \ - self.presence_penalties_cpu[last_req_index][:] - self.repetition_penalties_cpu[empty_index][:] = \ - self.repetition_penalties_cpu[last_req_index][:] + self.frequency_penalties_cpu[empty_index] = \ + self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[empty_index] = \ + self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[empty_index] = \ + self.repetition_penalties_cpu[last_req_index] self.min_tokens[empty_index] = self.min_tokens[last_req_index] self.stop_token_ids[empty_index] = \ self.stop_token_ids[last_req_index] @@ -369,7 +369,7 @@ def make_sampling_metadata( output_token_ids=output_token_ids, min_tokens=self.min_tokens[:self.num_reqs], stop_token_ids=self.stop_token_ids[:self.num_reqs], - no_penalties=self.no_penalties + no_penalties=self.no_penalties, ) def _construct_prompt_tokens_tensor( From 0e3179a816dcffc1a81eda7e09acc305b5775a17 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 10 Dec 2024 09:45:17 +0000 Subject: [PATCH 25/39] Comments Signed-off-by: Sourashis Roy --- vllm/model_executor/layers/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 4fc4880290e0b..f6f34cd49d953 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -35,9 +35,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, for padding because it does not correspond to any valid token ID in the vocabulary. output_tokens_tensor: The output tokens tensor. 
- presence_penalties: The presence penalties of shape [num_seqs, 1] - frequency_penalties: The frequency penalties of shape [num_seqs, 1] - repetition_penalties: The repetition penalties of shape [num_seqs, 1] + presence_penalties: The presence penalties of shape (num_seqs, ) + frequency_penalties: The frequency penalties of shape (num_seqs, ) + repetition_penalties: The repetition penalties of shape (num_seqs, ) """ num_seqs, vocab_size = logits.shape _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, From 09a73d01f69c2a793d566b5c20db2ba0dff9db82 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Dec 2024 11:17:37 +0000 Subject: [PATCH 26/39] Only pass output token_ids to sampler --- tests/v1/engine/test_engine_core.py | 3 ++- tests/v1/worker/test_gpu_input_batch.py | 6 ++++-- vllm/v1/worker/gpu_input_batch.py | 19 ++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 6 +++++- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index acf3823cf1444..c529cd21f384b 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -151,7 +151,8 @@ def test_engine_core_advanced_sampling(monkeypatch): m.setenv("VLLM_USE_V1", "1") """Setup the EngineCore.""" engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() + vllm_config = engine_args.create_engine_config( + usage_context=UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 4c8141bab8e71..694ce81ff6e22 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -169,12 +169,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): vocab_size=1024) reqs: List[CachedRequestState] = [] req_id_reqs = {} + req_id_output_token_ids = {} # Add requests for req_index in range(batch_size): req: CachedRequestState = _construct_cached_request_state(req_index) input_batch.add_request(req, req_index) reqs.append(req) req_id_reqs[req.req_id] = req + req_id_output_token_ids[req.req_id] = req.output_token_ids # Remove some requests req_ids_to_remove, req_indices_to_remove = _remove_requests( @@ -185,8 +187,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.condense(req_indices_to_remove) # Generate the sampling metadata - sampling_metadata = input_batch.make_sampling_metadata(req_id_reqs, - skip_copy=False) + sampling_metadata = input_batch.make_sampling_metadata( + req_id_output_token_ids, skip_copy=False) # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c5b728ccac16d..686cd443d9e68 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -66,6 +66,7 @@ def __init__( ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_prompt_token_ids = np.empty(max_num_reqs, dtype=np.int32) # Attention-related. self.block_table = torch.zeros( @@ -180,6 +181,7 @@ def add_request( # Copy the prompt token ids and output token ids. 
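Aside (illustrative sketch, not part of the patches): this patch starts recording each row's prompt length in num_prompt_token_ids so that prompts can later be recovered straight from the preallocated token_ids_cpu buffer instead of from the request objects. The bookkeeping in miniature:

import numpy as np

max_num_reqs, max_model_len = 4, 16
token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int32)
num_prompt_token_ids = np.zeros(max_num_reqs, dtype=np.int32)

def add_request(req_index, prompt, output):
    n = len(prompt)
    num_prompt_token_ids[req_index] = n
    token_ids_cpu[req_index, :n] = prompt                 # prompt first
    token_ids_cpu[req_index, n:n + len(output)] = output  # then outputs

add_request(0, prompt=[5, 6, 7], output=[9])
prompt_row = token_ids_cpu[0, :num_prompt_token_ids[0]]
print(prompt_row.tolist())                                # [5, 6, 7]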
num_prompt_tokens = len(request.prompt_token_ids) + self.num_prompt_token_ids[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens @@ -291,6 +293,8 @@ def condense(self, empty_req_indices: List[int]) -> None: # block_table_cpu. self.token_ids_cpu[empty_index] = self.token_ids_cpu[ last_req_index] + self.num_prompt_token_ids[empty_index] =\ + self.num_prompt_token_ids[last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table_cpu[empty_index] = self.block_table_cpu[ @@ -317,7 +321,7 @@ def condense(self, empty_req_indices: List[int]) -> None: def make_sampling_metadata( self, - requests: Dict[str, CachedRequestState], + req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: @@ -345,13 +349,12 @@ def make_sampling_metadata( # there are requests which need penalties to be applied. self.prompt_tokens_tensor = \ self._construct_prompt_tokens_tensor( - requests, self.vocab_size, device=self.device) + self.vocab_size, device=self.device) output_token_ids: List[List[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None - request = requests[req_id] # Currently we create a tensor for output_token_ids from scratch # at each step. However, for the penalties computation what we # need is stats about the token ids present in the output. This @@ -359,7 +362,7 @@ def make_sampling_metadata( # from scratch at each step. # TODO - Replace this with incremental update to output token # statistics. - output_token_ids.append(request.output_token_ids) + output_token_ids.append(req_id_output_token_ids[req_id]) return SamplingMetadata( temperature=self.temperature[:self.num_reqs], @@ -384,15 +387,13 @@ def make_sampling_metadata( def _construct_prompt_tokens_tensor( self, - requests, vocab_size: int, device: torch.device, ) -> torch.Tensor: prompt_token_ids: List[List[int]] = [] - for req_id in self.req_ids[:self.num_reqs]: - assert req_id is not None - request = requests[req_id] - prompt_token_ids.append(request.prompt_token_ids) + for index in range(self.num_reqs): + prompt_token_ids.append(self.token_ids_cpu[ + index, :self.num_prompt_token_ids[index]].tolist()) prompt_tokens_cpu_tensor = make_tensor_with_pad( prompt_token_ids, # use the value of vocab_size as a pad since we don't have a diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 48f0c33819644..b1fe3c3e20c88 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -365,8 +365,12 @@ def _prepare_sampling( or scheduler_output.scheduled_resumed_reqs): skip_copy = False # Create the sampling metadata. 
+ req_id_output_token_ids: Dict[str, List[int]] = \ + {req_id: req.output_token_ids \ + for req_id, req in self.requests.items()} + sampling_metadata = self.input_batch.make_sampling_metadata( - self.requests, skip_copy) + req_id_output_token_ids, skip_copy) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): From abda623cf42697ef01923c222c9f96fb3b4148a7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Dec 2024 13:36:56 +0000 Subject: [PATCH 27/39] Dummy --- vllm/v1/worker/gpu_input_batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 686cd443d9e68..32172c8fb7f55 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -391,6 +391,7 @@ def _construct_prompt_tokens_tensor( device: torch.device, ) -> torch.Tensor: prompt_token_ids: List[List[int]] = [] + for index in range(self.num_reqs): prompt_token_ids.append(self.token_ids_cpu[ index, :self.num_prompt_token_ids[index]].tolist()) From c1d6cd1760ef04650c5bd4ba497c6146f3b90ce9 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Dec 2024 13:37:57 +0000 Subject: [PATCH 28/39] Rerun tests --- vllm/v1/worker/gpu_input_batch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 32172c8fb7f55..686cd443d9e68 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -391,7 +391,6 @@ def _construct_prompt_tokens_tensor( device: torch.device, ) -> torch.Tensor: prompt_token_ids: List[List[int]] = [] - for index in range(self.num_reqs): prompt_token_ids.append(self.token_ids_cpu[ index, :self.num_prompt_token_ids[index]].tolist()) From 6861e9714aeece6269ec5839be92d7d932d6a0d8 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Dec 2024 15:04:40 +0000 Subject: [PATCH 29/39] Remove tolist for prompts --- vllm/v1/worker/gpu_input_batch.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 686cd443d9e68..71abd69da5642 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -8,7 +8,6 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata if TYPE_CHECKING: @@ -390,19 +389,19 @@ def _construct_prompt_tokens_tensor( vocab_size: int, device: torch.device, ) -> torch.Tensor: - prompt_token_ids: List[List[int]] = [] - for index in range(self.num_reqs): - prompt_token_ids.append(self.token_ids_cpu[ - index, :self.num_prompt_token_ids[index]].tolist()) - prompt_tokens_cpu_tensor = make_tensor_with_pad( - prompt_token_ids, - # use the value of vocab_size as a pad since we don't have a - # token_id of this value. - pad=vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=self.pin_memory, - ) + max_prompt_len = max(self.num_prompt_token_ids[:self.num_reqs]) + # use the value of vocab_size as a pad since we don't have a + # token_id of this value. 
+ padded_prompts = np.full((self.num_reqs, max_prompt_len), + vocab_size, + dtype=np.int64) + for i in range(self.num_reqs): + padded_prompts[i, :self.num_prompt_token_ids[i]] =\ + self.token_ids_cpu[i, :self.num_prompt_token_ids[i]] + prompt_tokens_cpu_tensor = torch.from_numpy(padded_prompts).to("cpu") + if self.pin_memory: + prompt_tokens_cpu_tensor = \ + prompt_tokens_cpu_tensor.pin_memory() prompt_tokens_tensor = prompt_tokens_cpu_tensor.to(device=device, non_blocking=True) return prompt_tokens_tensor From c5ab213fe97bb62f7ca7fefc4bd41e1e1bb85512 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 17 Dec 2024 15:14:48 +0000 Subject: [PATCH 30/39] Add TODO --- vllm/v1/worker/gpu_input_batch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 71abd69da5642..dcbb7b7825aa5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -392,6 +392,9 @@ def _construct_prompt_tokens_tensor( max_prompt_len = max(self.num_prompt_token_ids[:self.num_reqs]) # use the value of vocab_size as a pad since we don't have a # token_id of this value. + # TODO - Add a method in vllm/utils.py to pad a numpy array similar + # to make_tensor_with_pad which takes a list and move the logic + # there. padded_prompts = np.full((self.num_reqs, max_prompt_len), vocab_size, dtype=np.int64) From c79fad5bad71092f31850bee64ee74882237aec7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Dec 2024 07:57:04 +0000 Subject: [PATCH 31/39] Dummy Signed-off-by: Sourashis Roy --- vllm/v1/worker/gpu_input_batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index dcbb7b7825aa5..fcb1df1d27928 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -395,6 +395,7 @@ def _construct_prompt_tokens_tensor( # TODO - Add a method in vllm/utils.py to pad a numpy array similar # to make_tensor_with_pad which takes a list and move the logic # there. + padded_prompts = np.full((self.num_reqs, max_prompt_len), vocab_size, dtype=np.int64) From b3f7736a631f13be79120649aeca9376a8fb2869 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 18 Dec 2024 07:59:45 +0000 Subject: [PATCH 32/39] Dummy Signed-off-by: Sourashis Roy --- vllm/v1/worker/gpu_input_batch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index fcb1df1d27928..dcbb7b7825aa5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -395,7 +395,6 @@ def _construct_prompt_tokens_tensor( # TODO - Add a method in vllm/utils.py to pad a numpy array similar # to make_tensor_with_pad which takes a list and move the logic # there. 
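# A possible shape for the helper the TODO above describes (rough sketch only;
# the name make_ndarray_with_pad and its signature are assumptions, not
# something defined anywhere in this series):
import numpy as np
from typing import List

def make_ndarray_with_pad(x: List[List[int]],
                          pad: int,
                          dtype=np.int64) -> np.ndarray:
    """Pad a ragged list of rows into a (len(x), max_len) array."""
    max_len = max((len(row) for row in x), default=0)
    padded = np.full((len(x), max_len), pad, dtype=dtype)
    for i, row in enumerate(x):
        padded[i, :len(row)] = row
    return padded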
- padded_prompts = np.full((self.num_reqs, max_prompt_len), vocab_size, dtype=np.int64) From c74b9bb546c0faa45b56042a89807cb16ebfc09a Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 20 Dec 2024 05:26:39 +0000 Subject: [PATCH 33/39] Rerun tests Signed-off-by: Sourashis Roy --- vllm/v1/worker/gpu_input_batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index dcbb7b7825aa5..fcb1df1d27928 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -395,6 +395,7 @@ def _construct_prompt_tokens_tensor( # TODO - Add a method in vllm/utils.py to pad a numpy array similar # to make_tensor_with_pad which takes a list and move the logic # there. + padded_prompts = np.full((self.num_reqs, max_prompt_len), vocab_size, dtype=np.int64) From 31ba41fd49e32edde4ffc532fdad56c252931b56 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 20 Dec 2024 05:26:57 +0000 Subject: [PATCH 34/39] Rerun tests Signed-off-by: Sourashis Roy --- vllm/v1/worker/gpu_input_batch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index fcb1df1d27928..dcbb7b7825aa5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -395,7 +395,6 @@ def _construct_prompt_tokens_tensor( # TODO - Add a method in vllm/utils.py to pad a numpy array similar # to make_tensor_with_pad which takes a list and move the logic # there. - padded_prompts = np.full((self.num_reqs, max_prompt_len), vocab_size, dtype=np.int64) From a781c1100cbcbaebd12ec0906c489e33f968d953 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 21 Dec 2024 12:37:59 -0800 Subject: [PATCH 35/39] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index dcbb7b7825aa5..018401f6f9e7e 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -147,7 +147,7 @@ def __init__( dtype=torch.float, device="cpu", pin_memory=pin_memory) - self.repetition_penalties_cpu =\ + self.repetition_penalties_cpu = \ self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() @@ -205,7 +205,7 @@ def add_request( self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) - self.frequency_penalties_cpu[req_index] =\ + self.frequency_penalties_cpu[req_index] = \ sampling_params.frequency_penalty if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) @@ -292,7 +292,7 @@ def condense(self, empty_req_indices: List[int]) -> None: # block_table_cpu. 
self.token_ids_cpu[empty_index] = self.token_ids_cpu[ last_req_index] - self.num_prompt_token_ids[empty_index] =\ + self.num_prompt_token_ids[empty_index] = \ self.num_prompt_token_ids[last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] @@ -399,7 +399,7 @@ def _construct_prompt_tokens_tensor( vocab_size, dtype=np.int64) for i in range(self.num_reqs): - padded_prompts[i, :self.num_prompt_token_ids[i]] =\ + padded_prompts[i, :self.num_prompt_token_ids[i]] = \ self.token_ids_cpu[i, :self.num_prompt_token_ids[i]] prompt_tokens_cpu_tensor = torch.from_numpy(padded_prompts).to("cpu") if self.pin_memory: @@ -431,9 +431,9 @@ def no_top_k(self) -> bool: @property def no_penalties(self) -> bool: - return len(self.presence_penalties_reqs) == 0 and \ - len(self.frequency_penalties_reqs) == 0 and \ - len(self.repetition_penalties_reqs) == 0 + return (len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0) @property def max_num_logprobs(self) -> int: From 6bc8e0195cc2a406d38d6482aef0ad536ddd098f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 21 Dec 2024 12:53:56 -0800 Subject: [PATCH 36/39] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 018401f6f9e7e..668d37834eea0 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -151,7 +151,6 @@ def __init__( self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() - self.prompt_tokens_tensor: Optional[torch.Tensor] = None self.min_tokens: List[int] = [0] * max_num_reqs self.stop_token_ids: List[Set[int]] = [ set() for _ in range(max_num_reqs) @@ -244,6 +243,8 @@ def remove_request(self, req_id: str) -> Optional[int]: self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) + self.min_tokens[req_index] = 0 + self.stop_token_ids[req_index].clear() self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) @@ -259,6 +260,9 @@ def clear(self) -> None: self.frequency_penalties_reqs.clear() self.presence_penalties_reqs.clear() self.repetition_penalties_reqs.clear() + self.min_tokens = [0] * self.max_num_reqs + for stop_token_ids in self.stop_token_ids: + stop_token_ids.clear() self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() @@ -323,6 +327,7 @@ def make_sampling_metadata( req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, ) -> SamplingMetadata: + prompt_tokens_tensor: Optional[torch.Tensor] = None if not skip_copy: self.temperature[:self.num_reqs].copy_( self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) @@ -346,9 +351,8 @@ def make_sampling_metadata( # The prompt tokens are used only for applying penalties during # the sampling process. Hence copy these tensors only when # there are requests which need penalties to be applied. 
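# A minimal sketch of the pinned-buffer pattern the penalty fields rely on
# (names and sizes here are illustrative, not taken from the patch): the
# *_cpu_tensor lives in pinned host memory, *_cpu is a numpy view over the
# same storage, and only the active [:num_reqs] slice is copied to the
# device with a non-blocking transfer.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pin_memory = torch.cuda.is_available()

max_num_reqs = 8
penalties = torch.empty(max_num_reqs, dtype=torch.float, device=device)
penalties_cpu_tensor = torch.empty(max_num_reqs,
                                   dtype=torch.float,
                                   device="cpu",
                                   pin_memory=pin_memory)
penalties_cpu = penalties_cpu_tensor.numpy()  # shares storage, no copy

penalties_cpu[0] = 0.5  # cheap per-request update on the host side
num_reqs = 1
penalties[:num_reqs].copy_(penalties_cpu_tensor[:num_reqs], non_blocking=True)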
- self.prompt_tokens_tensor = \ - self._construct_prompt_tokens_tensor( - self.vocab_size, device=self.device) + prompt_tokens_tensor = self._construct_prompt_tokens_tensor( + self.vocab_size, device=self.device) output_token_ids: List[List[int]] = [] @@ -373,8 +377,7 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=self.prompt_tokens_tensor[:self.num_reqs] \ - if self.prompt_tokens_tensor is not None else None, + prompt_token_ids=prompt_tokens_tensor, frequency_penalties=self.frequency_penalties[:self.num_reqs], presence_penalties=self.presence_penalties[:self.num_reqs], repetition_penalties=self.repetition_penalties[:self.num_reqs], From 0912e3e97a3339e065ced5cd7989ea088e54e5f2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 21 Dec 2024 13:12:33 -0800 Subject: [PATCH 37/39] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 668d37834eea0..4eefdfa56cdb7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -65,7 +65,7 @@ def __init__( ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - self.num_prompt_token_ids = np.empty(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) # Attention-related. self.block_table = torch.zeros( @@ -179,7 +179,7 @@ def add_request( # Copy the prompt token ids and output token ids. num_prompt_tokens = len(request.prompt_token_ids) - self.num_prompt_token_ids[req_index] = num_prompt_tokens + self.num_prompt_tokens[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens @@ -243,8 +243,6 @@ def remove_request(self, req_id: str) -> Optional[int]: self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) - self.min_tokens[req_index] = 0 - self.stop_token_ids[req_index].clear() self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) @@ -260,9 +258,6 @@ def clear(self) -> None: self.frequency_penalties_reqs.clear() self.presence_penalties_reqs.clear() self.repetition_penalties_reqs.clear() - self.min_tokens = [0] * self.max_num_reqs - for stop_token_ids in self.stop_token_ids: - stop_token_ids.clear() self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() @@ -296,8 +291,8 @@ def condense(self, empty_req_indices: List[int]) -> None: # block_table_cpu. self.token_ids_cpu[empty_index] = self.token_ids_cpu[ last_req_index] - self.num_prompt_token_ids[empty_index] = \ - self.num_prompt_token_ids[last_req_index] + self.num_prompt_tokens[empty_index] = \ + self.num_prompt_tokens[last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table_cpu[empty_index] = self.block_table_cpu[ @@ -392,7 +387,7 @@ def _construct_prompt_tokens_tensor( vocab_size: int, device: torch.device, ) -> torch.Tensor: - max_prompt_len = max(self.num_prompt_token_ids[:self.num_reqs]) + max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() # use the value of vocab_size as a pad since we don't have a # token_id of this value. 
# TODO - Add a method in vllm/utils.py to pad a numpy array similar @@ -402,9 +397,9 @@ def _construct_prompt_tokens_tensor( vocab_size, dtype=np.int64) for i in range(self.num_reqs): - padded_prompts[i, :self.num_prompt_token_ids[i]] = \ - self.token_ids_cpu[i, :self.num_prompt_token_ids[i]] - prompt_tokens_cpu_tensor = torch.from_numpy(padded_prompts).to("cpu") + padded_prompts[i, :self.num_prompt_tokens[i]] = \ + self.token_ids_cpu[i, :self.num_prompt_tokens[i]] + prompt_tokens_cpu_tensor = torch.from_numpy(padded_prompts) if self.pin_memory: prompt_tokens_cpu_tensor = \ prompt_tokens_cpu_tensor.pin_memory() From 5dd4caa5b86519c61579154bc77fe32384a89f15 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 21 Dec 2024 13:34:08 -0800 Subject: [PATCH 38/39] optimize Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 41 ++++++++++++------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 4eefdfa56cdb7..dd8b60b17cc63 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -322,7 +322,7 @@ def make_sampling_metadata( req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, ) -> SamplingMetadata: - prompt_tokens_tensor: Optional[torch.Tensor] = None + prompt_token_ids: Optional[torch.Tensor] = None if not skip_copy: self.temperature[:self.num_reqs].copy_( self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) @@ -346,8 +346,7 @@ def make_sampling_metadata( # The prompt tokens are used only for applying penalties during # the sampling process. Hence copy these tensors only when # there are requests which need penalties to be applied. - prompt_tokens_tensor = self._construct_prompt_tokens_tensor( - self.vocab_size, device=self.device) + prompt_token_ids = self._make_prompt_token_ids_tensor() output_token_ids: List[List[int]] = [] @@ -372,7 +371,7 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=prompt_tokens_tensor, + prompt_token_ids=prompt_token_ids, frequency_penalties=self.frequency_penalties[:self.num_reqs], presence_penalties=self.presence_penalties[:self.num_reqs], repetition_penalties=self.repetition_penalties[:self.num_reqs], @@ -382,30 +381,22 @@ def make_sampling_metadata( no_penalties=self.no_penalties, ) - def _construct_prompt_tokens_tensor( - self, - vocab_size: int, - device: torch.device, - ) -> torch.Tensor: + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() - # use the value of vocab_size as a pad since we don't have a + prompt_token_ids_cpu_tensor = torch.empty( + (self.num_reqs, max_prompt_len), + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory) + prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() + prompt_token_ids[:] = ( + self.token_ids_cpu[:self.num_reqs, :max_prompt_len]) + # Use the value of vocab_size as a pad since we don't have a # token_id of this value. - # TODO - Add a method in vllm/utils.py to pad a numpy array similar - # to make_tensor_with_pad which takes a list and move the logic - # there. 
- padded_prompts = np.full((self.num_reqs, max_prompt_len), - vocab_size, - dtype=np.int64) for i in range(self.num_reqs): - padded_prompts[i, :self.num_prompt_tokens[i]] = \ - self.token_ids_cpu[i, :self.num_prompt_tokens[i]] - prompt_tokens_cpu_tensor = torch.from_numpy(padded_prompts) - if self.pin_memory: - prompt_tokens_cpu_tensor = \ - prompt_tokens_cpu_tensor.pin_memory() - prompt_tokens_tensor = prompt_tokens_cpu_tensor.to(device=device, - non_blocking=True) - return prompt_tokens_tensor + prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, + non_blocking=True) @property def num_reqs(self) -> int: From 12ab994ad87ab34968407a1f39d3f78cad611820 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sun, 22 Dec 2024 07:42:35 +0000 Subject: [PATCH 39/39] Make prompt_token_ids a class variable Signed-off-by: Sourashis Roy --- vllm/v1/worker/gpu_input_batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index dd8b60b17cc63..6c4d300ec6efe 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -155,6 +155,7 @@ def __init__( self.stop_token_ids: List[Set[int]] = [ set() for _ in range(max_num_reqs) ] + self.prompt_token_ids: Optional[torch.Tensor] = None # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own @@ -322,7 +323,6 @@ def make_sampling_metadata( req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, ) -> SamplingMetadata: - prompt_token_ids: Optional[torch.Tensor] = None if not skip_copy: self.temperature[:self.num_reqs].copy_( self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) @@ -346,7 +346,7 @@ def make_sampling_metadata( # The prompt tokens are used only for applying penalties during # the sampling process. Hence copy these tensors only when # there are requests which need penalties to be applied. - prompt_token_ids = self._make_prompt_token_ids_tensor() + self.prompt_token_ids = self._make_prompt_token_ids_tensor() output_token_ids: List[List[int]] = [] @@ -371,7 +371,7 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=prompt_token_ids, + prompt_token_ids=self.prompt_token_ids, frequency_penalties=self.frequency_penalties[:self.num_reqs], presence_penalties=self.presence_penalties[:self.num_reqs], repetition_penalties=self.repetition_penalties[:self.num_reqs],
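Taken together, these patches wire the OpenAI-style penalty options through the V1 input batch and sampler. A minimal usage sketch, assuming a local vLLM build with the V1 engine enabled; the model name and parameter values below are examples only:

from vllm import LLM, SamplingParams

# The tests above route through the V1 engine via VLLM_USE_V1=1 in the environment.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    presence_penalty=0.5,    # flat penalty once a token appears in the output
    frequency_penalty=0.5,   # penalty scaled by how often the token was emitted
    repetition_penalty=1.2,  # multiplicative penalty over prompt + output tokens
    min_tokens=16,           # keep stop/EOS suppressed for the first 16 tokens
)
outputs = llm.generate(["The quick brown fox"], params)
print(outputs[0].outputs[0].text)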