Skip to content

Commit

Permalink
Merge pull request #172 from JuliaSymbolics/ale/localbuffers
Browse files Browse the repository at this point in the history
Allow local buffers
  • Loading branch information
0x0f0f0f authored Oct 28, 2023
2 parents 7828344 + ce871dd commit b10998a
Showing 6 changed files with 62 additions and 63 deletions.
3 changes: 1 addition & 2 deletions src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
@@ -5,8 +5,7 @@ include("../docstrings.jl")
using DataStructures
using TermInterface
using TimerOutputs
using Metatheory:
alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
using Metatheory: alwaystrue, cleanast, binarize
using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.EMatchCompiler
27 changes: 23 additions & 4 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
@@ -4,9 +4,14 @@

abstract type AbstractENode end

import Metatheory: maybelock!

const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple}
const EClassId = Int64
const TermTypes = Dict{Tuple{Any,Int},Type}
# TODO document bindings
const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}}
const DEFAULT_BUFFER_SIZE = 1048576

struct ENodeLiteral <: AbstractENode
value
@@ -190,14 +195,21 @@ mutable struct EGraph
termtypes::TermTypes
numclasses::Int
numnodes::Int
"If we use global buffers we may need to lock. Defaults to true."
needslock::Bool
"Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
buffer::Vector{Bindings}
"Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
merges_buffer::Vector{Tuple{Int,Int}}
lock::ReentrantLock
end


"""
EGraph(expr)
Construct an EGraph from a starting symbolic expression `expr`.
"""
function EGraph()
function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE)
EGraph(
IntDisjointSet(),
Dict{EClassId,EClass}(),
@@ -210,12 +222,19 @@ function EGraph()
TermTypes(),
0,
0,
# 0
needslock,
Bindings[],
Tuple{Int,Int}[],
ReentrantLock(),
)
end

function EGraph(e; keepmeta = false)
g = EGraph()
function maybelock!(f::Function, g::EGraph)
g.needslock ? lock(f, g.buffer_lock) : f()
end

function EGraph(e; keepmeta = false, kwargs...)
g = EGraph(kwargs...)
keepmeta && addanalysis!(g, :metadata_analysis)
g.root = addexpr!(g, e; keepmeta = keepmeta)
g
30 changes: 14 additions & 16 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
@@ -119,8 +119,8 @@ function eqsat_search!(
)::Int
n_matches = 0

lock(BUFFER_LOCK) do
empty!(BUFFER[])
maybelock!(g) do
empty!(g.buffer)
end

@debug "SEARCHING"
@@ -166,13 +166,13 @@ function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId
end

function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
push!(MERGES_BUF[], (id, instantiate_enode!(buf, g, rule.right)))
push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right)))
nothing
end

function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int)
pat_to_inst = direction == 1 ? rule.right : rule.left
push!(MERGES_BUF[], (id, instantiate_enode!(bindings, g, pat_to_inst)))
push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst)))
nothing
end

@@ -207,46 +207,44 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
isnothing(r) && return nothing
rcid = addexpr!(g, r)
push!(MERGES_BUF[], (id, rcid))
push!(g.merges_buffer, (id, rcid))
return nothing
end



function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams)
i = 0
@assert isempty(MERGES_BUF[])
@assert isempty(g.merges_buffer)

@debug "APPLYING $(length(BUFFER[])) matches"
@debug "APPLYING $(length(g.buffer)) matches"
maybelock!(g) do
while !isempty(g.buffer)

lock(BUFFER_LOCK) do
while !isempty(BUFFER[])
if reached(g, params.goal)
@debug "Goal reached"
rep.reason = :goalreached
return
end

bindings = popfirst!(BUFFER[])
bindings = pop!(g.buffer)
rule_idx, id = bindings[0]
direction = sign(rule_idx)
rule_idx = abs(rule_idx)
rule = theory[rule_idx]


halt_reason = lock(MERGES_BUF_LOCK) do
apply_rule!(bindings, g, rule, id, direction)
end
halt_reason = apply_rule!(bindings, g, rule, id, direction)

if !isnothing(halt_reason)
rep.reason = halt_reason
return
end
end
end
lock(MERGES_BUF_LOCK) do
while !isempty(MERGES_BUF[])
(l, r) = popfirst!(MERGES_BUF[])
maybelock!(g) do
while !isempty(g.merges_buffer)
(l, r) = pop!(g.merges_buffer)
merge!(g, l, r)
end
end
19 changes: 1 addition & 18 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
@@ -2,31 +2,14 @@ module Metatheory

using DataStructures

import Base.ImmutableDict

const Bindings = ImmutableDict{Int,Tuple{Int,Int}}
const DEFAULT_BUFFER_SIZE = 1048576
const BUFFER = Ref(CircularDeque{Bindings}(DEFAULT_BUFFER_SIZE))
const BUFFER_LOCK = ReentrantLock()
const MERGES_BUF = Ref(CircularDeque{Tuple{Int,Int}}(DEFAULT_BUFFER_SIZE))
const MERGES_BUF_LOCK = ReentrantLock()

function resetbuffers!(bufsize)
BUFFER[] = CircularDeque{Bindings}(bufsize)
MERGES_BUF[] = CircularDeque{Tuple{Int,Int}}(bufsize)
end

function __init__()
resetbuffers!(DEFAULT_BUFFER_SIZE)
end

using Base.Meta
using Reexport
using TermInterface

@inline alwaystrue(x) = true

function lookup_pat end
function maybelock! end

include("docstrings.jl")
include("utils.jl")
42 changes: 21 additions & 21 deletions src/ematch_compiler.jl
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ module EMatchCompiler

using TermInterface
using ..Patterns
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, LL, maybelock!

function ematcher(p::Any)
function literal_ematcher(next, g, data, bindings)
@@ -48,7 +48,7 @@ function predicate_ematcher(p::PatVar, pred)
end
end
end

function ematcher(p::PatVar)
pred_matcher = predicate_ematcher(p, p.predicate)

@@ -115,14 +115,14 @@ function ematcher(p::PatTerm)

for n in g[car(data)]
if canbindtop(n)
loop(LL(arguments(n),1), bindings, ematchers)
loop(LL(arguments(n), 1), bindings, ematchers)
end
end
end
end
end


const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int, Int}}()
const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}()

"""
Substitutions are efficiently represented in memory as vector of tuples of two integers.
@@ -137,30 +137,30 @@ The format is as follows
* The end of a substitution is delimited by (0,0)
"""
function ematcher_yield(p, npvars::Int, direction::Int)
em = ematcher(p)
function ematcher_yield(g, rule_idx, id)::Int
n_matches = 0
em(g, (id,), EMPTY_ECLASS_DICT) do b,n
lock(BUFFER_LOCK) do
push!(BUFFER[], assoc(b, 0, (rule_idx * direction, id)))
n_matches+=1
end
end
n_matches
em = ematcher(p)
function ematcher_yield(g, rule_idx, id)::Int
n_matches = 0
em(g, (id,), EMPTY_ECLASS_DICT) do b, n
maybelock!(g) do
push!(g.buffer, assoc(b, 0, (rule_idx * direction, id)))
n_matches += 1
end
end
n_matches
end
end

ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1)
ematcher_yield(p, npvars) = ematcher_yield(p, npvars, 1)

function ematcher_yield_bidir(l, r, npvars::Int)
eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1)
function ematcher_yield_bidir(g, rule_idx, id)::Int
eml(g,rule_idx,id) + emr(g,rule_idx,id)
end
eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1)
function ematcher_yield_bidir(g, rule_idx, id)::Int
eml(g, rule_idx, id) + emr(g, rule_idx, id)
end
end

ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p")

export ematcher_yield, ematcher_yield_bidir

end
end
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -151,8 +151,8 @@ end
macro matchable(expr)
@assert expr.head == :struct
name = expr.args[2]
if name isa Expr
name.head === :(<:) && (name = name.args[1])
if name isa Expr
name.head === :(<:) && (name = name.args[1])
name isa Expr && name.head === :curly && (name = name.args[1])
end
fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args)

0 comments on commit b10998a

Please sign in to comment.