diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index f30f5eb63f7..61aa4de75c2 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -158,15 +158,15 @@ def update_regex_vocab_mask(self):
             return
 
         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)
 
         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # apply_vocab_mask is a staticmethod, so instance access is equivalent
 
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
@@ -174,7 +174,7 @@ def update_regex_vocab_mask(self):
             grammar.fill_vocab_mask(self.vocab_mask, i)
 
         # Move the mask to the device if needed
-        self.vocab_mask = grammar.move_vocab_mask(self.vocab_mask, self.device)
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
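
For context, the rename matters because the loop `for i, grammar in enumerate(self.grammars)` rebinds the name `grammar`, so the pre-fix `grammar.move_vocab_mask(...)` after the loop used whatever the *last* list element was, which can be `None` for requests without a grammar constraint. Below is a minimal, self-contained sketch of that shadowing hazard; `Grammar`, `grammars`, and `mask` are illustrative stand-ins, not the sglang types:

```python
# Minimal sketch of the shadowing bug this diff fixes (stand-in types,
# not the sglang code). In SamplingBatchInfo.grammars, entries are None
# for requests that have no grammar constraint.

class Grammar:
    def fill_vocab_mask(self, mask, i):
        mask[i] = True

    def move_vocab_mask(self, mask, device):
        return mask  # stand-in; the real method may move the mask to `device`

grammars = [Grammar(), None]  # the last request has no grammar
first_grammar = next(g for g in grammars if g)
mask = [False] * len(grammars)

# The loop variable rebinds the name `grammar` used by the pre-fix code ...
for i, grammar in enumerate(grammars):
    if grammar is not None:
        grammar.fill_vocab_mask(mask, i)

# ... so after the loop, `grammar` is the last element (None here), and the
# pre-fix `grammar.move_vocab_mask(...)` would raise AttributeError.
# `first_grammar` is always a non-None grammar by construction.
assert grammars[-1] is None
mask = first_grammar.move_vocab_mask(mask, device="cpu")
print(mask)  # [True, False]
```

Giving the pre-loop grammar its own name (`first_grammar`) rather than reusing the loop variable makes the post-loop calls safe regardless of which batch entries carry grammars.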