Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiled Pattern Matching #204

Merged
merged 8 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"

[extensions]
Plotting = ["GraphViz"]

[compat]
AutoHashEquals = "2.1.0"
DocStringExtensions = "0.8, 0.9"
Reexport = "0.2, 1"
TimerOutputs = "0.5"
TermInterface = "0.4.1"
TimerOutputs = "0.5"
julia = "1.9"

[extras]
Expand All @@ -26,9 +32,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "SafeTestsets", "Literate"]

[weakdeps]
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"

[extensions]
Plotting = ["GraphViz"]
4 changes: 3 additions & 1 deletion examples/propositional_logic_theory.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# # Rewriting

using Metatheory, TermInterface

fold = @theory p q begin
(p::Bool == q::Bool) => (p == q)
(p::Bool || q::Bool) => (p || q)
Expand Down Expand Up @@ -74,7 +76,7 @@ function prove(
params.goal = (g::EGraph) -> in_same_class(g, ids...)
saturate!(g, t, params)
ex = extract!(g, astsize)
if !Metatheory.isexpr(ex)
if !TermInterface.isexpr(ex)
return ex
end
if hash(ex) ∈ hist
Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.VecExprModule

using Metatheory: alwaystrue, cleanast, UNDEF_ID_VEC, should_quote_operation, OptBuffer
using Metatheory: alwaystrue, cleanast, UNDEF_ID_VEC, maybe_quote_operation, OptBuffer

import Metatheory: to_expr

Expand Down
2 changes: 1 addition & 1 deletion src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function addexpr!(g::EGraph, se)::Id
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(should_quote_operation(h) ? nameof(h) : h, hash(ar)))
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
Expand Down
6 changes: 3 additions & 3 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ function eqsat_search!(

if rule isa BidirRule
for i in ids_left
n_matches += rule.ematcher_new_left!(g, rule_idx, i, rule.ematcher_stack, ematch_buffer)
n_matches += rule.ematcher_new_left!(g, rule_idx, i, rule.stack, ematch_buffer)
end
for i in ids_right
n_matches += rule.ematcher_new_right!(g, rule_idx, i, rule.ematcher_stack, ematch_buffer)
n_matches += rule.ematcher_new_right!(g, rule_idx, i, rule.stack, ematch_buffer)
end
else
for i in ids_left
n_matches += rule.ematcher!(g, rule_idx, i, rule.ematcher_stack, ematch_buffer)
n_matches += rule.ematcher!(g, rule_idx, i, rule.stack, ematch_buffer)
end
end
n_matches - prev_matches > 0 && @debug "Rule $rule_idx: $rule produced $(n_matches - prev_matches) matches"
Expand Down
12 changes: 7 additions & 5 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ using Reexport
function to_expr end

# TODO: document
function should_quote_operation end
should_quote_operation(::Function) = true
should_quote_operation(x) = false
Base.@inline maybe_quote_operation(x::Union{Function,DataType}) = nameof(x)
Base.@inline maybe_quote_operation(x) = x

include("docstrings.jl")

Expand All @@ -22,8 +21,7 @@ export OptBuffer

const UNDEF_ID_VEC = Vector{Id}(undef, 0)

using TermInterface
using TermInterface: isexpr
@reexport using TermInterface

"""
@matchable struct Foo fields... end [HeadType]
Expand Down Expand Up @@ -64,6 +62,10 @@ export @timer
include("Patterns.jl")
@reexport using .Patterns

include("match_compiler.jl")
export match_compile


include("ematch_compiler.jl")
export ematch_compile

Expand Down
6 changes: 4 additions & 2 deletions src/Patterns.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Patterns

using Metatheory: cleanast, alwaystrue, should_quote_operation
using Metatheory: cleanast, alwaystrue, maybe_quote_operation
using AutoHashEquals
using TermInterface
using Metatheory.VecExprModule
Expand Down Expand Up @@ -92,7 +92,9 @@ struct PatExpr <: AbstractPat
n::VecExpr
function PatExpr(iscall, op, args::Vector)
op_hash = hash(op)
qop, qop_hash = should_quote_operation(op) ? (nameof(op), hash(nameof(op))) : (op, op_hash)
# Should call `nameof` on op if Function or DataType. Identity otherwise
qop = maybe_quote_operation(op)
qop_hash = hash(qop)
ar = length(args)
signature = hash(qop, hash(ar))

Expand Down
31 changes: 12 additions & 19 deletions src/Rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,25 @@ variables.
matcher
patvars::Vector{Symbol}
ematcher!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function RewriteRule(l, r, ematcher!)
function RewriteRule(l, r, matcher!, ematcher!)
pvars = patvars(l) ∪ patvars(r)
# sort!(pvars)
setdebrujin!(l, pvars)
setdebrujin!(r, pvars)
RewriteRule(l, r, matcher(l), pvars, ematcher!, OptBuffer{UInt16}(STACK_SIZE))
RewriteRule(l, r, matcher!, pvars, ematcher!, OptBuffer{UInt16}(STACK_SIZE))
end

Base.show(io::IO, r::RewriteRule) = print(io, :($(r.left) --> $(r.right)))


function (r::RewriteRule)(term)
# n == 1 means that exactly one term of the input (term,) was matched
success(bindings, n) = n == 1 ? instantiate(term, r.right, bindings) : nothing

success(pvars...) = instantiate(term, r.right, pvars)
try
r.matcher(success, (term,), EMPTY_DICT)
r.matcher(term, success, r.stack)
catch err
rethrow(err)
throw(RuleRewriteError(r, term, err))
Expand All @@ -98,7 +97,7 @@ with the EGraphs backend.
patvars::Vector{Symbol}
ematcher_new_left!
ematcher_new_right!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function EqualityRule(l, r, ematcher_new_left!, ematcher_new_right!)
Expand Down Expand Up @@ -136,7 +135,7 @@ backend. If two terms, corresponding to the left and right hand side of an
patvars::Vector{Symbol}
ematcher_new_left!
ematcher_new_right!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function UnequalRule(l, r, ematcher_new_left!, ematcher_new_right!)
Expand Down Expand Up @@ -177,30 +176,24 @@ Dynamic rule
matcher
patvars::Vector{Symbol} # useful set of pattern variables
ematcher!
ematcher_stack::OptBuffer{UInt16}
stack::OptBuffer{UInt16}
end

function DynamicRule(l, r::Function, ematcher!, rhs_code = nothing)
function DynamicRule(l, r::Function, matcher, ematcher!, rhs_code = nothing)
pvars = patvars(l)
setdebrujin!(l, pvars)
isnothing(rhs_code) && (rhs_code = repr(rhs_code))

DynamicRule(l, r, rhs_code, matcher(l), pvars, ematcher!, OptBuffer{UInt16}(512))
DynamicRule(l, r, rhs_code, matcher, pvars, ematcher!, OptBuffer{UInt16}(512))
end


Base.show(io::IO, r::DynamicRule) = print(io, :($(r.left) => $(r.rhs_code)))

function (r::DynamicRule)(term)
# n == 1 means that exactly one term of the input (term,) was matched
success(bindings, n) =
if n == 1
bvals = [bindings[i] for i in 1:length(r.patvars)]
return r.rhs_fun(term, nothing, bvals...)
end

success(bindings...) = r.rhs_fun(term, nothing, bindings...)
try
return r.matcher(success, (term,), EMPTY_DICT)
return r.matcher(term, success, r.stack)
catch err
throw(RuleRewriteError(r, term, err))
end
Expand Down
31 changes: 21 additions & 10 deletions src/Syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Metatheory.Patterns
using Metatheory.Rules
using TermInterface

using Metatheory: alwaystrue, cleanast, ematch_compile
using Metatheory: alwaystrue, cleanast, ematch_compile, match_compile

export @rule
export @theory
Expand Down Expand Up @@ -373,14 +373,17 @@ macro rule(args...)
end
ematcher_left_expr = esc(ematch_compile(lhs, ppvars, 1))

matcher_left_expr = match_compile(lhs, pvars)


if RuleType == DynamicRule
rhs_rewritten = rewrite_rhs(r)
rhs_consequent = makeconsequent(rhs_rewritten)
params = Expr(:tuple, :_lhs_expr, :_egraph, pvars...)
rhs = :($(esc(params)) -> $(esc(rhs_consequent)))
return quote
$(__source__)
DynamicRule($lhs, $rhs, $ematcher_left_expr, $(QuoteNode(rhs_consequent)))
DynamicRule($lhs, $rhs, $matcher_left_expr, $ematcher_left_expr, $(QuoteNode(rhs_consequent)))
end
end

Expand All @@ -393,7 +396,7 @@ macro rule(args...)

quote
$(__source__)
($RuleType)($lhs, $rhs, $ematcher_left_expr)
($RuleType)($lhs, $rhs, $matcher_left_expr, $ematcher_left_expr)
end
end

Expand Down Expand Up @@ -470,16 +473,24 @@ macro capture(args...)

pvars = Symbol[]
lhs = makepattern(lhs, pvars, slots, __module__)
bind = Expr(
:block,
map(key -> :($(esc(key)) = getindex(__MATCHES__, findfirst((==)($(QuoteNode(key))), $pvars))), pvars)...,
)
bind_exprs = Expr[]

for key in pvars
idx = findfirst((==)(key), pvars)
push!(bind_exprs, :($(esc(key)) = __MATCHES__[$idx]))
end

setdebrujin!(lhs, pvars)

matcher_left_expr = match_compile(lhs, pvars)


ret = quote
$(__source__)
rule = DynamicRule($lhs, (_lhs_expr, _egraph, pvars...) -> pvars, (x...) -> nothing)
rule = DynamicRule($lhs, (_lhs_expr, _egraph, pvars...) -> pvars, $matcher_left_expr, nothing)
__MATCHES__ = rule($(esc(ex)))
if __MATCHES__ !== nothing
$bind
if !isnothing(__MATCHES__)
$(bind_exprs...)
true
else
false
Expand Down
Loading
Loading