Skip to content

Commit

Permalink
[Feature] Grammar concat and union (#149)
Browse files Browse the repository at this point in the history
This PR provides APIs for finding the concatenation and union for grammars. The concat API is provided as a static function of Grammar, while union is a private function in testing. The latter will be moved out later with more tests.
  • Loading branch information
Ubospica authored Jan 11, 2025
1 parent 50dac1a commit 608dd5c
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 25 deletions.
8 changes: 8 additions & 0 deletions cpp/grammar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ Grammar Grammar::BuiltinJSONGrammar() {
return grammar;
}

Grammar Grammar::Union(const std::vector<Grammar>& grammars) {
return GrammarUnionFunctor::Apply(grammars);
}

Grammar Grammar::Concat(const std::vector<Grammar>& grammars) {
return GrammarConcatFunctor::Apply(grammars);
}

std::ostream& operator<<(std::ostream& os, const Grammar& grammar) {
os << grammar.ToString();
return os;
Expand Down
13 changes: 13 additions & 0 deletions cpp/grammar_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ class GrammarBuilder {
int32_t root_rule_id = GetRuleId(root_rule_name);
XGRAMMAR_CHECK(root_rule_id != -1)
<< "The root rule with name \"" << root_rule_name << "\" is not found.";
return Get(root_rule_id);
}

/*!
* \brief Get the result grammar. This function will also set the root rule to the rule with
* the specified id. The rule should be already added to the grammar.
* \param root_rule_id The id of the root rule.
*/
Grammar Get(int32_t root_rule_id) {
XGRAMMAR_CHECK(
root_rule_id >= 0 && root_rule_id < static_cast<int32_t>(grammar_->rules_.size())
) << "The root rule id "
<< root_rule_id << " is out of bound.";
grammar_->root_rule_id_ = root_rule_id;

return Grammar(grammar_);
Expand Down
134 changes: 120 additions & 14 deletions cpp/grammar_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace xgrammar {

/*************************** Impl of grammar functors ***************************/

/*!
* \brief Eliminates single-element sequence or choice or character class in the grammar.
* \example `A ::= choices("a")` --> `A ::= "a"` (the body is a string)
Expand Down Expand Up @@ -310,22 +312,126 @@ class ByteStringFuser : public GrammarMutator {
}
};

// Return the list of all normalizers in the class. The normalizers are applied one by one.
std::vector<std::unique_ptr<GrammarMutator>> GrammarNormalizer::GetNormalizerList() {
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators;
normalizer_mutators.emplace_back(std::make_unique<SingleElementExprEliminator>());
normalizer_mutators.emplace_back(std::make_unique<NestedRuleUnwrapper>());
normalizer_mutators.emplace_back(std::make_unique<ByteStringFuser>());
return normalizer_mutators;
}
class GrammarNormalizerImpl : public GrammarMutator {
public:
GrammarNormalizerImpl() = default;

Grammar GrammarNormalizer::Apply(const Grammar& grammar) {
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators = GetNormalizerList();
old_grammar_ = grammar;
for (auto& mutator : normalizer_mutators) {
old_grammar_ = mutator->Apply(old_grammar_);
Grammar Apply(const Grammar& grammar) final {
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators = GetNormalizerList();
old_grammar_ = grammar;
for (auto& mutator : normalizer_mutators) {
old_grammar_ = mutator->Apply(old_grammar_);
}
return old_grammar_;
}

private:
// Return the list of all normalizers in the class. The normalizers are applied one by one.
std::vector<std::unique_ptr<GrammarMutator>> GetNormalizerList() {
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators;
normalizer_mutators.emplace_back(std::make_unique<SingleElementExprEliminator>());
normalizer_mutators.emplace_back(std::make_unique<NestedRuleUnwrapper>());
normalizer_mutators.emplace_back(std::make_unique<ByteStringFuser>());
return normalizer_mutators;
}
};

class SubGrammarAdder : public GrammarMutator {
public:
SubGrammarAdder() = default;

protected:
/*!
* \brief Visit a subgrammar and add the rules to the builder.
* \param grammar The subgrammar to visit.
* \return The new id of the root rule of this subgrammar.
*/
int32_t VisitSubGrammar(const Grammar& grammar) {
old_grammar_ = grammar;
new_rule_ids_names.reserve(grammar->NumRules());
new_rule_ids_names.clear();
for (int i = 0; i < static_cast<int>(grammar->NumRules()); ++i) {
auto new_name = builder_.GetNewRuleName(grammar->GetRule(i).name);
auto new_id = builder_.AddEmptyRule(new_name);
new_rule_ids_names.emplace_back(new_id, new_name);
}
for (int i = 0; i < static_cast<int>(grammar->NumRules()); ++i) {
auto rule = grammar->GetRule(i);
cur_rule_name_ = new_rule_ids_names[i].second;
auto new_body_expr_id = VisitExpr(rule.body_expr_id);
builder_.UpdateRuleBody(new_rule_ids_names[i].first, new_body_expr_id);
auto new_lookahead_assertion_id = VisitLookaheadAssertion(rule.lookahead_assertion_id);
builder_.AddLookaheadAssertion(new_rule_ids_names[i].first, new_lookahead_assertion_id);
}
return new_rule_ids_names[grammar->GetRootRuleId()].first;
}

int32_t VisitRuleRef(const RuleExpr& rule_expr) final {
return builder_.AddRuleRef(new_rule_ids_names[rule_expr[0]].first);
}

std::vector<std::pair<int32_t, std::string>> new_rule_ids_names;
};

class GrammarUnionFunctorImpl : public SubGrammarAdder {
public:
GrammarUnionFunctorImpl() = default;

Grammar Apply(const std::vector<Grammar>& grammars) {
builder_ = GrammarBuilder();
auto root_rule_id = builder_.AddEmptyRule("root");

std::vector<int32_t> new_root_choices;
new_root_choices.reserve(grammars.size());

for (const auto& grammar : grammars) {
auto new_root_id_for_grammar = VisitSubGrammar(grammar);
auto new_rule_ref = builder_.AddRuleRef(new_root_id_for_grammar);
auto new_rule_ref_seq = builder_.AddSequence({new_rule_ref});
new_root_choices.push_back(new_rule_ref_seq);
}

builder_.UpdateRuleBody(root_rule_id, builder_.AddChoices(new_root_choices));
return builder_.Get(root_rule_id);
}
};

class GrammarConcatFunctorImpl : public SubGrammarAdder {
public:
GrammarConcatFunctorImpl() = default;

Grammar Apply(const std::vector<Grammar>& grammars) {
builder_ = GrammarBuilder();
auto root_rule_id = builder_.AddEmptyRule("root");

std::vector<int32_t> new_root_sequence;
new_root_sequence.reserve(grammars.size());

for (const auto& grammar : grammars) {
auto new_root_id_for_grammar = VisitSubGrammar(grammar);
auto new_rule_ref = builder_.AddRuleRef(new_root_id_for_grammar);
new_root_sequence.push_back(new_rule_ref);
}

auto new_root_seq = builder_.AddSequence(new_root_sequence);
builder_.UpdateRuleBody(root_rule_id, builder_.AddChoices({new_root_seq}));

return builder_.Get(root_rule_id);
}
return old_grammar_;
};

/*************************** Forward grammar functors to their impl ***************************/

Grammar GrammarNormalizer::Apply(const Grammar& grammar) {
return GrammarNormalizerImpl().Apply(grammar);
}

Grammar GrammarUnionFunctor::Apply(const std::vector<Grammar>& grammars) {
return GrammarUnionFunctorImpl().Apply(grammars);
}

Grammar GrammarConcatFunctor::Apply(const std::vector<Grammar>& grammars) {
return GrammarConcatFunctorImpl().Apply(grammars);
}

} // namespace xgrammar
24 changes: 19 additions & 5 deletions cpp/grammar_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,32 @@ using GrammarVisitor = GrammarFunctor<void, ReturnType>;
*/
using GrammarMutator = GrammarFunctor<int32_t, Grammar>;

/*************************** Grammar manipulation methods ***************************/
/****** All below methods are implemented as functor to hide the implementation ******/

/*!
* \brief Normalize a Grammar: expand the nested rules, combine consequent sequences and strings,
* etc.
*/
class GrammarNormalizer : public GrammarMutator {
class GrammarNormalizer {
public:
using GrammarMutator::GrammarMutator;
static Grammar Apply(const Grammar& grammar);
};

Grammar Apply(const Grammar& grammar) final;
/*!
* \brief Find the union of multiple grammars as a new grammar.
*/
class GrammarUnionFunctor {
public:
static Grammar Apply(const std::vector<Grammar>& grammars);
};

private:
std::vector<std::unique_ptr<GrammarMutator>> GetNormalizerList();
/*!
* \brief Find the concatenation of multiple grammars as a new grammar.
*/
class GrammarConcatFunctor {
public:
static Grammar Apply(const std::vector<Grammar>& grammars);
};

} // namespace xgrammar
Expand Down
4 changes: 3 additions & 1 deletion cpp/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
.def_static("from_ebnf", &Grammar::FromEBNF)
.def_static("from_json_schema", &Grammar::FromJSONSchema)
.def_static("from_regex", &Grammar::FromRegex)
.def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar);
.def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar)
.def_static("union", &Grammar::Union)
.def_static("concat", &Grammar::Concat);

auto pyCompiledGrammar = py::class_<CompiledGrammar>(m, "CompiledGrammar");
pyCompiledGrammar.def_property_readonly("grammar", &CompiledGrammar::GetGrammar)
Expand Down
17 changes: 13 additions & 4 deletions include/xgrammar/grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,20 @@ class Grammar {
static Grammar BuiltinJSONGrammar();

/*!
* \brief Convert regex string to EBNF grammar string.
* \param regex The regex string.
* \returns The EBNF grammar string.
* \brief Create a grammar that matches any of the grammars in the list. That is equivalent to
* using the `|` operator to concatenate the grammars in the list.
* \param grammars The grammars to create the union of.
* \returns The union of the grammars.
*/
// static std::string _RegexToEBNF(const std::string& regex);
static Grammar Union(const std::vector<Grammar>& grammars);

/*!
* \brief Create a grammar that matches the concatenation of the grammars in the list. That is
* equivalent to using the `+` operator to concatenate the grammars in the list.
* \param grammars The grammars to create the concatenation of.
* \returns The concatenation of the grammars.
*/
static Grammar Concat(const std::vector<Grammar>& grammars);

/*! \brief Print a BNF grammar. */
friend std::ostream& operator<<(std::ostream& os, const Grammar& grammar);
Expand Down
20 changes: 19 additions & 1 deletion python/xgrammar/grammar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module provides classes representing grammars."""

import json
from typing import Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Type, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -159,3 +159,21 @@ def builtin_json_grammar() -> "Grammar":
The JSON grammar.
"""
return Grammar._create_from_handle(_core.Grammar.builtin_json_grammar())

@staticmethod
def concat(*grammars: "Grammar") -> "Grammar":
"""Create a grammar that matches the concatenation of the grammars in the list. That is
equivalent to using the `+` operator to concatenate the grammars in the list.
Parameters
----------
grammars : List[Grammar]
The grammars to create the concatenation of.
Returns
-------
grammar : Grammar
The concatenation of the grammars.
"""
grammar_handles = [grammar._handle for grammar in grammars]
return Grammar._create_from_handle(_core.Grammar.concat(grammar_handles))
36 changes: 36 additions & 0 deletions python/xgrammar/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,44 @@ def _get_masked_tokens_from_bitmask(
def _get_matcher_from_grammar_and_tokenizer_info(
grammar: Union[Grammar, str], tokenizer_info: Optional[TokenizerInfo] = None, **kwargs
) -> GrammarMatcher:
"""Create a GrammarMatcher from a grammar and tokenizer info.
Parameters
----------
grammar : Union[Grammar, str]
The grammar to create the matcher from. Can be either a Grammar object or a string
containing EBNF grammar.
tokenizer_info : Optional[TokenizerInfo], default: None
Information about the tokenizer to use with this grammar. If None, an empty
TokenizerInfo will be created.
**kwargs
Additional keyword arguments to pass to the GrammarMatcher constructor.
Returns
-------
matcher : GrammarMatcher
The created grammar matcher.
"""
if tokenizer_info is None:
tokenizer_info = TokenizerInfo([])
grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
return GrammarMatcher(compiled_grammar, **kwargs)


def _get_grammar_union(*grammars: "Grammar") -> "Grammar":
"""Create a grammar that matches any of the grammars in the list. That is equivalent to
using the `|` operator to concatenate the grammars in the list.
Parameters
----------
grammars : List[Grammar]
The grammars to create the union of.
Returns
-------
grammar : Grammar
The union of the grammars.
"""
grammar_handles = [grammar._handle for grammar in grammars]
return Grammar._create_from_handle(_core.Grammar.union(grammar_handles))
Loading

0 comments on commit 608dd5c

Please sign in to comment.