Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Dec 6, 2024
1 parent d58aa8f commit 07efbe9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/sglang/srt/sampling/sampling_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,23 @@ def update_regex_vocab_mask(self):
return

# find a grammar from the list
grammar = next(grammar for grammar in self.grammars if grammar)
first_grammar = next(grammar for grammar in self.grammars if grammar)

# maybe we can reuse the existing mask?
self.vocab_mask = grammar.allocate_vocab_mask(
self.vocab_mask = first_grammar.allocate_vocab_mask(
vocab_size=self.vocab_size,
batch_size=len(self.temperatures),
device=self.device,
)
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method

# Apply the mask
for i, grammar in enumerate(self.grammars):
if grammar and not grammar.finished:
grammar.fill_vocab_mask(self.vocab_mask, i)

# Move the mask to the device if needed
self.vocab_mask = grammar.move_vocab_mask(self.vocab_mask, self.device)
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
Expand Down

0 comments on commit 07efbe9

Please sign in to comment.