From 314ade0c257b2b6e6f7c0667b668fd2b4036f52d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date: Thu, 26 Dec 2024 16:37:38 -0800
Subject: [PATCH 1/2] [V1] Fix yapf

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
---
 vllm/v1/sample/ops/penalties.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py
index 91ebaf9269f32..0ac9596108f62 100644
--- a/vllm/v1/sample/ops/penalties.py
+++ b/vllm/v1/sample/ops/penalties.py
@@ -2,8 +2,8 @@
 
 import torch
 
-from vllm.model_executor.layers.utils import (
-    apply_penalties as _apply_penalties)
+from vllm.model_executor.layers.utils import (apply_penalties as
+                                              _apply_penalties)
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 
 

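Background for the hunk above: the alias exists because penalties.py defines
its own apply_penalties wrapper with the same name as the imported helper;
patch 2 below drops the alias by renaming the wrapper instead. A minimal,
self-contained sketch of the shadowing problem the alias works around; the
names here are illustrative stand-ins, not vLLM's real modules:

    # Stand-in for the helper that the real module imports as:
    #   from vllm.model_executor.layers.utils import (apply_penalties as
    #                                                 _apply_penalties)
    def _apply_penalties(logits):
        # Pretend penalty kernel: shift every logit down by one.
        return [x - 1.0 for x in logits]

    def apply_penalties(logits):
        # Without the aliased import, the call below would resolve to this
        # very function and recurse forever instead of reaching the helper.
        return _apply_penalties(logits)

    assert apply_penalties([3.0]) == [2.0]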
From 66f8cbcfb738e34d37c72c72834794e12ddfa3c2 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date: Thu, 26 Dec 2024 16:43:59 -0800
Subject: [PATCH 2/2] [V1] Rename apply_penalties to apply_all_penalties

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
---
 vllm/v1/sample/ops/penalties.py | 24 +++++++++++++-----------
 vllm/v1/sample/sampler.py       | 16 ++++++++--------
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py
index 0ac9596108f62..2796d049457d0 100644
--- a/vllm/v1/sample/ops/penalties.py
+++ b/vllm/v1/sample/ops/penalties.py
@@ -2,8 +2,7 @@
 
 import torch
 
-from vllm.model_executor.layers.utils import (apply_penalties as
-                                              _apply_penalties)
+from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 
 
@@ -17,27 +16,30 @@ def apply_min_token_penalties(logits: torch.Tensor,
     """
     min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
     for index, min_token in enumerate(min_tokens):
-        if (len(output_token_ids[index]) < min_token):
+        if len(output_token_ids[index]) < min_token:
             for stop_token_id in stop_token_ids[index]:
                 min_tokens_logits_to_penalize.append((index, stop_token_id))
     if min_tokens_logits_to_penalize:
         logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
 
 
-def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
-                    presence_penalties: torch.Tensor,
-                    frequency_penalties: torch.Tensor,
-                    repetition_penalties: torch.Tensor,
-                    output_token_ids: List[List[int]]) -> torch.Tensor:
+def apply_all_penalties(
+    logits: torch.Tensor,
+    prompt_token_ids: torch.Tensor,
+    presence_penalties: torch.Tensor,
+    frequency_penalties: torch.Tensor,
+    repetition_penalties: torch.Tensor,
+    output_token_ids: List[List[int]],
+) -> torch.Tensor:
     """
     Applies presence, frequency and repetition penalties to the logits.
     """
     _, vocab_size = logits.shape
     output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                           logits.device)
-    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
-                            presence_penalties, frequency_penalties,
-                            repetition_penalties)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)
 
 
 def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
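For reference, the min-token logic kept unchanged in the hunk above masks
every stop token to -inf until a sequence has produced min_tokens outputs.
A standalone sketch assembled from that function; the driver values at the
bottom are made up for illustration:

    from typing import List, Set, Tuple

    import torch

    def apply_min_token_penalties(logits: torch.Tensor,
                                  output_token_ids: List[List[int]],
                                  stop_token_ids: List[Set[int]],
                                  min_tokens: List[int]) -> None:
        # Collect (sequence, stop_token) pairs for sequences that are still
        # shorter than their min_tokens requirement.
        to_penalize: List[Tuple[int, int]] = []
        for index, min_token in enumerate(min_tokens):
            if len(output_token_ids[index]) < min_token:
                for stop_token_id in stop_token_ids[index]:
                    to_penalize.append((index, stop_token_id))
        if to_penalize:
            # Advanced indexing: one fused write masks all collected pairs.
            logits[tuple(zip(*to_penalize))] = -float("inf")

    # Sequence 0 has produced 1 token but must produce 3, so its stop
    # token (id 2) is forced to -inf and cannot be sampled yet.
    logits = torch.zeros(1, 4)
    apply_min_token_penalties(logits, [[7]], [{2}], [3])
    assert logits[0, 2] == -float("inf")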
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 1e38453a0ff28..7cd42ca211a22 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -6,8 +6,8 @@
 
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.ops.penalties import (apply_min_token_penalties,
-                                          apply_penalties)
+from vllm.v1.sample.ops.penalties import (apply_all_penalties,
+                                          apply_min_token_penalties)
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 
 _SAMPLING_EPS = 1e-5
@@ -127,10 +127,10 @@ def apply_penalties(
                                   sampling_metadata.min_tokens)
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
-            logits = apply_penalties(logits,
-                                     sampling_metadata.prompt_token_ids,
-                                     sampling_metadata.presence_penalties,
-                                     sampling_metadata.frequency_penalties,
-                                     sampling_metadata.repetition_penalties,
-                                     sampling_metadata.output_token_ids)
+            logits = apply_all_penalties(
+                logits, sampling_metadata.prompt_token_ids,
+                sampling_metadata.presence_penalties,
+                sampling_metadata.frequency_penalties,
+                sampling_metadata.repetition_penalties,
+                sampling_metadata.output_token_ids)
         return logits
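For readers tracing what apply_all_penalties ultimately computes: the
underlying helper applies the standard repetition, frequency, and presence
penalties. The sketch below shows the textbook formulas for a single
sequence; it is not a copy of
vllm.model_executor.layers.utils.apply_penalties, which works batched and
also folds prompt tokens into the repetition mask:

    from typing import List

    import torch

    def penalize(logits: torch.Tensor, output_token_ids: List[int],
                 presence: float, frequency: float,
                 repetition: float) -> torch.Tensor:
        # Count how often each vocab id appeared in the output so far.
        vocab_size = logits.shape[0]
        counts = torch.bincount(
            torch.tensor(output_token_ids, dtype=torch.long),
            minlength=vocab_size).to(logits.dtype)
        seen = counts > 0
        # Repetition penalty: shrink positive logits and amplify negative
        # ones for every token that has already been generated.
        logits = torch.where(seen & (logits > 0), logits / repetition, logits)
        logits = torch.where(seen & (logits <= 0), logits * repetition, logits)
        # Frequency penalty scales with the count; presence penalty is a
        # flat cost for having appeared at all.
        return logits - frequency * counts - presence * seen.to(logits.dtype)

    # Token 1 was emitted twice, so it takes the largest hit.
    out = penalize(torch.tensor([1.0, 2.0, -1.0]), [1, 1],
                   presence=0.5, frequency=0.2, repetition=1.2)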