Skip to content

Commit

Permalink
[Refactor] Refactor GrammarMatcherBase (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubospica authored Jan 13, 2025
1 parent a01d8ca commit 5703f39
Show file tree
Hide file tree
Showing 22 changed files with 595 additions and 268 deletions.
3 changes: 3 additions & 0 deletions cpp/compiled_grammar_data_structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ struct AdaptiveTokenMask {
*/
class CompiledGrammar::Impl {
public:
/******************* The grammar and tokenizer info *******************/

/*! \brief The grammar for the GrammarMatcher. */
Grammar grammar;

/*! \brief The tokenizer information. */
TokenizerInfo tokenizer_info;

Expand Down
2 changes: 1 addition & 1 deletion cpp/ebnf_script_creator.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2023 by Contributors
* \file tokenizer.cc
* \file xgrammar/ebnf_script_creator.cc
*/
#include "ebnf_script_creator.h"

Expand Down
191 changes: 102 additions & 89 deletions cpp/compiler.cc → cpp/grammar_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "compiled_grammar_data_structure.h"
#include "grammar_data_structure.h"
#include "grammar_functor.h"
#include "grammar_matcher_base.h"
#include "support/thread_pool.h"
#include "support/thread_safe_cache.h"
Expand Down Expand Up @@ -217,23 +218,87 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
);
}

CompiledGrammar MultiThreadCompileGrammar(
const Grammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads
) {
/******************* GrammarCompiler::Impl *******************/

class GrammarCompiler::Impl {
public:
Impl(const TokenizerInfo& tokenizer_info, int max_threads, bool cache_enabled)
: tokenizer_info_(tokenizer_info),
max_threads_(max_threads),
cache_enabled_(cache_enabled),
compile_json_schema_cache_(GetCompileJSONSchemaCacheFunc(cache_enabled_)),
compile_builtin_json_grammar_cache_(GetCompileBuiltinJSONGrammarCacheFunc(cache_enabled_)),
compile_grammar_cache_(GetCompileGrammarCacheFunc(cache_enabled_)) {}

CompiledGrammar CompileBuiltinJSONGrammar();

CompiledGrammar CompileJSONSchema(
const std::string& schema,
bool any_whitespace,
std::optional<int> indent,
std::optional<std::pair<std::string, std::string>> separators,
bool strict_mode = true
);

CompiledGrammar CompileGrammar(const Grammar& grammar);

void ClearCache();

private:
/*! \brief Multi-thread compile the grammar. */
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar, int max_threads);

/*! \brief The cache for the compiled grammar of a JSON schema. */
using SchemaKey =
std::tuple<std::string, bool, std::optional<int>, std::pair<std::string, std::string>, bool>;

/*! \brief The cache function for the compiled grammar of a JSON schema. */
std::function<CompiledGrammar(const SchemaKey&)> GetCompileJSONSchemaCacheFunc(bool cache_enabled
);

/*! \brief The cache function for the compiled grammar for pure JSON. */
std::function<CompiledGrammar()> GetCompileBuiltinJSONGrammarCacheFunc(bool cache_enabled);

using GrammarKey = std::pair<std::string, std::string>;
/*! \brief The cache function for the compiled grammar for a given grammar. */
std::function<CompiledGrammar(const GrammarKey&)> GetCompileGrammarCacheFunc(bool cache_enabled);

/*! \brief The vocabulary associated with this storage class. */
const TokenizerInfo tokenizer_info_;
/*! \brief The maximum number of threads to use. */
const int max_threads_;
/*! \brief Whether the cache is enabled. */
const bool cache_enabled_;
/*! \brief The cache for the compiled grammar of a JSON schema. */
ThreadSafeCache<SchemaKey, CompiledGrammar> compile_json_schema_cache_;
/*! \brief The cache for the compiled grammar for JSON. */
ThreadSafeCache<CompiledGrammar> compile_builtin_json_grammar_cache_;
/*! \brief The cache for the compiled grammar for bnf grammar. */
ThreadSafeCache<GrammarKey, CompiledGrammar> compile_grammar_cache_;
};

CompiledGrammar GrammarCompiler::Impl::MultiThreadCompileGrammar(Grammar grammar, int max_threads) {
using RuleExprType = Grammar::Impl::RuleExprType;

auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();

compiled_grammar_impl->grammar = grammar;
compiled_grammar_impl->tokenizer_info = tokenizer_info;
compiled_grammar_impl->tokenizer_info = tokenizer_info_;

// Step 1. Compute the ids of rules that can be empty
compiled_grammar_impl->grammar->allow_empty_rule_ids = AllowEmptyRuleAnalyzer::Apply(grammar);

if (tokenizer_info.GetVocabSize() == 0) {
auto root_rule_id = grammar->GetRootRuleId();

if (tokenizer_info_.GetVocabSize() == 0) {
return CompiledGrammar(compiled_grammar_impl);
}

// Find the corresponding adaptive token mask for:
// Step 2. Compute the adaptive token mask cache
// The token mask cache is computed for these positions in the grammar:
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
// 2. All byte strings (with element_in_string=0, 1, 2, ...)
// since other positions will be expanded to the above positions

// TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly.
// Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do
Expand All @@ -246,7 +311,6 @@ CompiledGrammar MultiThreadCompileGrammar(
adaptive_token_mask_cache_mutex.emplace();
}

auto root_rule_id = grammar->GetRootRuleId();
for (int32_t rule_id = 0; rule_id < static_cast<int>(grammar->NumRules()); ++rule_id) {
auto rule = grammar->GetRule(rule_id);
auto rule_body = grammar->GetRuleExpr(rule.body_expr_id);
Expand All @@ -268,8 +332,8 @@ CompiledGrammar MultiThreadCompileGrammar(
auto add_adaptive_token_mask = [&](const StackElement& stack_element) {
auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar, stack_element);
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
tokenizer_info.GetVocabSize(),
tokenizer_info.GetSortedDecodedVocab(),
tokenizer_info_.GetVocabSize(),
tokenizer_info_.GetSortedDecodedVocab(),
rule_id != root_rule_id
);
if (max_threads > 1) {
Expand Down Expand Up @@ -311,95 +375,45 @@ CompiledGrammar MultiThreadCompileGrammar(
if (max_threads > 1) {
thread_pool->Join();
}

return CompiledGrammar(compiled_grammar_impl);
}

/******************* GrammarCompiler::Impl *******************/

class GrammarCompiler::Impl {
public:
Impl(const TokenizerInfo& tokenizer_info, int max_threads, bool cache_enabled)
: tokenizer_info_(tokenizer_info),
max_threads_(max_threads),
cache_enabled_(cache_enabled),
compile_json_schema_cache_(GetCompileJSONSchemaCacheFunc(cache_enabled_)),
compile_builtin_json_grammar_cache_(GetCompileBuiltinJSONGrammarCacheFunc(cache_enabled_)),
compile_grammar_cache_(GetCompileGrammarCacheFunc(cache_enabled_)) {}

CompiledGrammar CompileBuiltinJSONGrammar();

CompiledGrammar CompileJSONSchema(
const std::string& schema,
bool any_whitespace,
std::optional<int> indent,
std::optional<std::pair<std::string, std::string>> separators,
bool strict_mode = true
);

CompiledGrammar CompileGrammar(const Grammar& grammar);

void ClearCache();

private:
/*! \brief The cache for the compiled grammar of a JSON schema. */
using SchemaKey =
std::tuple<std::string, bool, std::optional<int>, std::pair<std::string, std::string>, bool>;

std::function<CompiledGrammar(const SchemaKey&)> GetCompileJSONSchemaCacheFunc(bool cache_enabled
) {
if (!cache_enabled) {
return nullptr;
}
return [&](const SchemaKey& key) {
auto [schema, any_whitespace, indent, separators, strict_mode] = key;
auto grammar =
Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
return MultiThreadCompileGrammar(grammar, tokenizer_info_, max_threads_);
};
std::function<CompiledGrammar(const GrammarCompiler::Impl::SchemaKey&)>
GrammarCompiler::Impl::GetCompileJSONSchemaCacheFunc(bool cache_enabled) {
if (!cache_enabled) {
return nullptr;
}
return [&](const SchemaKey& key) {
auto [schema, any_whitespace, indent, separators, strict_mode] = key;
auto grammar = Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
return MultiThreadCompileGrammar(grammar, max_threads_);
};
}

std::function<CompiledGrammar()> GetCompileBuiltinJSONGrammarCacheFunc(bool cache_enabled) {
if (!cache_enabled) {
return nullptr;
}
return [&]() {
return MultiThreadCompileGrammar(
Grammar::BuiltinJSONGrammar(), tokenizer_info_, max_threads_
);
};
std::function<CompiledGrammar()> GrammarCompiler::Impl::GetCompileBuiltinJSONGrammarCacheFunc(
bool cache_enabled
) {
if (!cache_enabled) {
return nullptr;
}
return [&]() { return MultiThreadCompileGrammar(Grammar::BuiltinJSONGrammar(), max_threads_); };
}

using GrammarKey = std::pair<std::string, std::string>;

std::function<CompiledGrammar(const GrammarKey&)> GetCompileGrammarCacheFunc(bool cache_enabled) {
if (!cache_enabled) {
return nullptr;
}
return [&](const GrammarKey& key) {
auto [grammar_str, root_rule_name] = key;
return MultiThreadCompileGrammar(
Grammar::FromEBNF(grammar_str, root_rule_name), tokenizer_info_, max_threads_
);
};
std::function<CompiledGrammar(const GrammarCompiler::Impl::GrammarKey&)>
GrammarCompiler::Impl::GetCompileGrammarCacheFunc(bool cache_enabled) {
if (!cache_enabled) {
return nullptr;
}

/*! \brief The vocabulary associated with this storage class. */
const TokenizerInfo tokenizer_info_;
/*! \brief The maximum number of threads to use. */
const int max_threads_;
/*! \brief Whether the cache is enabled. */
const bool cache_enabled_;
/*! \brief The cache for the compiled grammar of a JSON schema. */
ThreadSafeCache<SchemaKey, CompiledGrammar> compile_json_schema_cache_;
/*! \brief The cache for the compiled grammar for JSON. */
ThreadSafeCache<CompiledGrammar> compile_builtin_json_grammar_cache_;
/*! \brief The cache for the compiled grammar for bnf grammar. */
ThreadSafeCache<GrammarKey, CompiledGrammar> compile_grammar_cache_;
};
return [&](const GrammarKey& key) {
auto [grammar_str, root_rule_name] = key;
return MultiThreadCompileGrammar(Grammar::FromEBNF(grammar_str, root_rule_name), max_threads_);
};
}

CompiledGrammar GrammarCompiler::Impl::CompileBuiltinJSONGrammar() {
if (!cache_enabled_) {
return MultiThreadCompileGrammar(Grammar::BuiltinJSONGrammar(), tokenizer_info_, max_threads_);
return MultiThreadCompileGrammar(Grammar::BuiltinJSONGrammar(), max_threads_);
}
return compile_builtin_json_grammar_cache_.Get();
}
Expand All @@ -414,7 +428,6 @@ CompiledGrammar GrammarCompiler::Impl::CompileJSONSchema(
if (!cache_enabled_) {
return MultiThreadCompileGrammar(
Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode),
tokenizer_info_,
max_threads_
);
}
Expand All @@ -427,7 +440,7 @@ CompiledGrammar GrammarCompiler::Impl::CompileJSONSchema(

CompiledGrammar GrammarCompiler::Impl::CompileGrammar(const Grammar& grammar) {
if (!cache_enabled_) {
return MultiThreadCompileGrammar(grammar, tokenizer_info_, max_threads_);
return MultiThreadCompileGrammar(grammar, max_threads_);
}
auto key = std::make_pair(grammar.ToString(), grammar->GetRootRule().name);
return compile_grammar_cache_.Get(key);
Expand Down
6 changes: 6 additions & 0 deletions cpp/grammar_data_structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ class Grammar::Impl {
/*! \brief The id of the root rule. */
int32_t root_rule_id_ = -1;

public:
/******************* Aux information for matching *******************/

/*! \brief The ids of the rules that are allowed to be empty. */
std::vector<int32_t> allow_empty_rule_ids;

friend class GrammarBuilder;
friend class GrammarSerializer;
friend class GrammarDeserializer;
Expand Down
Loading

0 comments on commit 5703f39

Please sign in to comment.