Skip to content

Commit

Permalink
Overlap for xgrammar
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Dec 6, 2024
1 parent 1dd38c6 commit d58aa8f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
4 changes: 4 additions & 0 deletions python/sglang/srt/constrained/outlines_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def allocate_vocab_mask(
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask

def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
Expand Down
9 changes: 4 additions & 5 deletions python/sglang/srt/constrained/xgrammar_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,11 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(vocab_mask, idx)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
if vocab_mask.device.type != logits.device.type:
# vocab_mask must then be on the same device as logits
# when applying the token bitmask, so we check and move if needed
vocab_mask = vocab_mask.to(logits.device)
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask.to(device, non_blocking=True)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
apply_token_bitmask_inplace(logits, vocab_mask)

def copy(self):
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/sampling/sampling_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@ def update_regex_vocab_mask(self):
)
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method

# Apply the mask
for i, grammar in enumerate(self.grammars):
if grammar and not grammar.finished:
grammar.fill_vocab_mask(self.vocab_mask, i)

# Move the mask to the device if needed
self.vocab_mask = grammar.move_vocab_mask(self.vocab_mask, self.device)

def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

Expand Down

0 comments on commit d58aa8f

Please sign in to comment.