Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow local buffers #172

Merged
merged 7 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ include("../docstrings.jl")
using DataStructures
using TermInterface
using TimerOutputs
using Metatheory:
alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
using Metatheory: alwaystrue, cleanast, binarize
using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.EMatchCompiler
Expand Down
27 changes: 23 additions & 4 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

abstract type AbstractENode end

import Metatheory: maybelock!

const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple}
const EClassId = Int64
const TermTypes = Dict{Tuple{Any,Int},Type}
# TODO document bindings
const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}}
const DEFAULT_BUFFER_SIZE = 1048576

struct ENodeLiteral <: AbstractENode
value
Expand Down Expand Up @@ -190,14 +195,21 @@ mutable struct EGraph
termtypes::TermTypes
numclasses::Int
numnodes::Int
"If we use global buffers we may need to lock. Defaults to true."
needslock::Bool
"Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
buffer::Vector{Bindings}
"Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
merges_buffer::Vector{Tuple{Int,Int}}
lock::ReentrantLock
end


"""
EGraph(expr)
Construct an EGraph from a starting symbolic expression `expr`.
"""
function EGraph()
function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE)
EGraph(
IntDisjointSet(),
Dict{EClassId,EClass}(),
Expand All @@ -210,12 +222,19 @@ function EGraph()
TermTypes(),
0,
0,
# 0
needslock,
Bindings[],
Tuple{Int,Int}[],
ReentrantLock(),
)
end

function EGraph(e; keepmeta = false)
g = EGraph()
function maybelock!(f::Function, g::EGraph)
g.needslock ? lock(f, g.buffer_lock) : f()
end

function EGraph(e; keepmeta = false, kwargs...)
g = EGraph(kwargs...)
keepmeta && addanalysis!(g, :metadata_analysis)
g.root = addexpr!(g, e; keepmeta = keepmeta)
g
Expand Down
30 changes: 14 additions & 16 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ function eqsat_search!(
)::Int
n_matches = 0

lock(BUFFER_LOCK) do
empty!(BUFFER[])
maybelock!(g) do
empty!(g.buffer)
end

@debug "SEARCHING"
Expand Down Expand Up @@ -166,13 +166,13 @@ function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId
end

function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
push!(MERGES_BUF[], (id, instantiate_enode!(buf, g, rule.right)))
push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right)))
nothing
end

function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int)
pat_to_inst = direction == 1 ? rule.right : rule.left
push!(MERGES_BUF[], (id, instantiate_enode!(bindings, g, pat_to_inst)))
push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst)))
nothing
end

Expand Down Expand Up @@ -207,46 +207,44 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
isnothing(r) && return nothing
rcid = addexpr!(g, r)
push!(MERGES_BUF[], (id, rcid))
push!(g.merges_buffer, (id, rcid))
return nothing
end



function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams)
i = 0
@assert isempty(MERGES_BUF[])
@assert isempty(g.merges_buffer)

@debug "APPLYING $(length(BUFFER[])) matches"
@debug "APPLYING $(length(g.buffer)) matches"
maybelock!(g) do
while !isempty(g.buffer)

lock(BUFFER_LOCK) do
while !isempty(BUFFER[])
if reached(g, params.goal)
@debug "Goal reached"
rep.reason = :goalreached
return
end

bindings = popfirst!(BUFFER[])
bindings = pop!(g.buffer)
rule_idx, id = bindings[0]
direction = sign(rule_idx)
rule_idx = abs(rule_idx)
rule = theory[rule_idx]


halt_reason = lock(MERGES_BUF_LOCK) do
apply_rule!(bindings, g, rule, id, direction)
end
halt_reason = apply_rule!(bindings, g, rule, id, direction)

if !isnothing(halt_reason)
rep.reason = halt_reason
return
end
end
end
lock(MERGES_BUF_LOCK) do
while !isempty(MERGES_BUF[])
(l, r) = popfirst!(MERGES_BUF[])
maybelock!(g) do
while !isempty(g.merges_buffer)
(l, r) = pop!(g.merges_buffer)
merge!(g, l, r)
end
end
Expand Down
19 changes: 1 addition & 18 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,14 @@ module Metatheory

using DataStructures

import Base.ImmutableDict

const Bindings = ImmutableDict{Int,Tuple{Int,Int}}
const DEFAULT_BUFFER_SIZE = 1048576
const BUFFER = Ref(CircularDeque{Bindings}(DEFAULT_BUFFER_SIZE))
const BUFFER_LOCK = ReentrantLock()
const MERGES_BUF = Ref(CircularDeque{Tuple{Int,Int}}(DEFAULT_BUFFER_SIZE))
const MERGES_BUF_LOCK = ReentrantLock()

function resetbuffers!(bufsize)
BUFFER[] = CircularDeque{Bindings}(bufsize)
MERGES_BUF[] = CircularDeque{Tuple{Int,Int}}(bufsize)
end

function __init__()
resetbuffers!(DEFAULT_BUFFER_SIZE)
end

using Base.Meta
using Reexport
using TermInterface

@inline alwaystrue(x) = true

function lookup_pat end
function maybelock! end

include("docstrings.jl")
include("utils.jl")
Expand Down
42 changes: 21 additions & 21 deletions src/ematch_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module EMatchCompiler

using TermInterface
using ..Patterns
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, LL, maybelock!

function ematcher(p::Any)
function literal_ematcher(next, g, data, bindings)
Expand Down Expand Up @@ -48,7 +48,7 @@ function predicate_ematcher(p::PatVar, pred)
end
end
end

function ematcher(p::PatVar)
pred_matcher = predicate_ematcher(p, p.predicate)

Expand Down Expand Up @@ -115,14 +115,14 @@ function ematcher(p::PatTerm)

for n in g[car(data)]
if canbindtop(n)
loop(LL(arguments(n),1), bindings, ematchers)
loop(LL(arguments(n), 1), bindings, ematchers)
end
end
end
end
end


const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int, Int}}()
const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}()

"""
Substitutions are efficiently represented in memory as vector of tuples of two integers.
Expand All @@ -137,30 +137,30 @@ The format is as follows
* The end of a substitution is delimited by (0,0)
"""
function ematcher_yield(p, npvars::Int, direction::Int)
em = ematcher(p)
function ematcher_yield(g, rule_idx, id)::Int
n_matches = 0
em(g, (id,), EMPTY_ECLASS_DICT) do b,n
lock(BUFFER_LOCK) do
push!(BUFFER[], assoc(b, 0, (rule_idx * direction, id)))
n_matches+=1
end
end
n_matches
em = ematcher(p)
function ematcher_yield(g, rule_idx, id)::Int
n_matches = 0
em(g, (id,), EMPTY_ECLASS_DICT) do b, n
maybelock!(g) do
push!(g.buffer, assoc(b, 0, (rule_idx * direction, id)))
n_matches += 1
end
end
n_matches
end
end

ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1)
ematcher_yield(p, npvars) = ematcher_yield(p, npvars, 1)

function ematcher_yield_bidir(l, r, npvars::Int)
eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1)
function ematcher_yield_bidir(g, rule_idx, id)::Int
eml(g,rule_idx,id) + emr(g,rule_idx,id)
end
eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1)
function ematcher_yield_bidir(g, rule_idx, id)::Int
eml(g, rule_idx, id) + emr(g, rule_idx, id)
end
end

ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p")

export ematcher_yield, ematcher_yield_bidir

end
end
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ end
macro matchable(expr)
@assert expr.head == :struct
name = expr.args[2]
if name isa Expr
name.head === :(<:) && (name = name.args[1])
if name isa Expr
name.head === :(<:) && (name = name.args[1])
name isa Expr && name.head === :curly && (name = name.args[1])
end
fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args)
Expand Down