From d0aa5114ac6c64526ec71bfbb182ffc8531a9a3f Mon Sep 17 00:00:00 2001 From: Will Kimmerer Date: Sat, 21 Oct 2023 19:11:40 -0400 Subject: [PATCH 1/6] allow non-global buffers --- src/EGraphs/EGraphs.jl | 3 ++- src/EGraphs/saturation.jl | 54 ++++++++++++++++++++++++--------------- src/Metatheory.jl | 10 +++++--- src/ematch_compiler.jl | 15 ++++++----- src/utils.jl | 19 ++++++++++++++ 5 files changed, 69 insertions(+), 32 deletions(-) diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index d418c3db..600f8a6d 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -6,7 +6,8 @@ using DataStructures using TermInterface using TimerOutputs using Metatheory: - alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings + alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings, + lockmergesbuffer, lockbuffer using Metatheory.Patterns using Metatheory.Rules using Metatheory.EMatchCompiler diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 663b68e6..cd96b855 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -72,6 +72,10 @@ Base.@kwdef mutable struct SaturationParams threaded::Bool = false timer::Bool = true printiter::Bool = false + buffer::CircularDeque{Bindings} = BUFFER[] + buffer_lock::ReentrantLock = BUFFER_LOCK + merges_buffer::CircularDeque{Tuple{Int,Int}} = MERGES_BUF[] + merges_buffer_lock::ReentrantLock = MERGES_BUF_LOCK end # function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64} @@ -117,11 +121,12 @@ function eqsat_search!( theory::Vector{<:AbstractRule}, scheduler::AbstractScheduler, report::SaturationReport, + params::SaturationParams )::Int n_matches = 0 - lock(BUFFER_LOCK) do - empty!(BUFFER[]) + lockbuffer(params) do + empty!(params.buffer) end for (rule_idx, rule) in enumerate(theory) @@ -133,7 +138,7 @@ function eqsat_search!( ids = cached_ids(g, rule.left) rule isa BidirRule && (ids = ids ∪ cached_ids(g, rule.right)) for i in ids - n_matches += rule.ematcher!(g, rule_idx, i) + n_matches += rule.ematcher!(g, rule_idx, i, params) end inform!(scheduler, rule, n_matches) end @@ -163,19 +168,25 @@ function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId add!(g, ENodeTerm(eh, new_op, T, map(arg -> instantiate_enode!(bindings, g, arg), args))) end -function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) - push!(MERGES_BUF[], (id, instantiate_enode!(buf, g, rule.right))) +function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction, params::SaturationParams) + push!(params.merges_buffer, (id, instantiate_enode!(buf, g, rule.right))) nothing end -function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int) +function apply_rule!( + bindings::Bindings, g::EGraph, rule::EqualityRule, + id::EClassId, direction::Int, params::SaturationParams +) pat_to_inst = direction == 1 ? rule.right : rule.left - push!(MERGES_BUF[], (id, instantiate_enode!(bindings, g, pat_to_inst))) + push!(params.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) nothing end -function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClassId, direction::Int) +function apply_rule!( + bindings::Bindings, g::EGraph, + rule::UnequalRule, id::EClassId, direction::Int, params::SaturationParams +) pat_to_inst = direction == 1 ? rule.right : rule.left other_id = instantiate_enode!(bindings, g, pat_to_inst) @@ -200,12 +211,15 @@ function instantiate_actual_param!(bindings::Bindings, g::EGraph, i) return eclass end -function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClassId, direction::Int) +function apply_rule!( + bindings::Bindings, g::EGraph, rule::DynamicRule, + id::EClassId, direction::Int, params::SaturationParams +) f = rule.rhs_fun r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) isnothing(r) && return nothing rcid = addexpr!(g, r) - push!(MERGES_BUF[], (id, rcid)) + push!(params.merges_buffer, (id, rcid)) return nothing end @@ -213,25 +227,25 @@ end function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams) i = 0 - @assert isempty(MERGES_BUF[]) + @assert isempty(params.merges_buffer) - lock(BUFFER_LOCK) do - while !isempty(BUFFER[]) + lockbuffer(params) do + while !isempty(params.buffer) if reached(g, params.goal) @log "Goal reached" rep.reason = :goalreached return end - bindings = popfirst!(BUFFER[]) + bindings = popfirst!(params.buffer) rule_idx, id = bindings[0] direction = sign(rule_idx) rule_idx = abs(rule_idx) rule = theory[rule_idx] - halt_reason = lock(MERGES_BUF_LOCK) do - apply_rule!(bindings, g, rule, id, direction) + halt_reason = lockmergesbuffer(params) do + apply_rule!(bindings, g, rule, id, direction, params) end if !isnothing(halt_reason) @@ -240,9 +254,9 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end end end - lock(MERGES_BUF_LOCK) do - while !isempty(MERGES_BUF[]) - (l, r) = popfirst!(MERGES_BUF[]) + lockmergesbuffer(params) do + while !isempty(params.merges_buffer) + (l, r) = popfirst!(params.merges_buffer) merge!(g, l, r) end end @@ -267,7 +281,7 @@ function eqsat_step!( setiter!(scheduler, curr_iter) - @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report) + @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report, params) @timeit report.to "Apply" eqsat_apply!(g, theory, report, params) diff --git a/src/Metatheory.jl b/src/Metatheory.jl index 29e09eef..fb8fe97e 100644 --- a/src/Metatheory.jl +++ b/src/Metatheory.jl @@ -4,16 +4,18 @@ using DataStructures import Base.ImmutableDict +createbuffer(::Type{T}, size) where T = CircularDeque{T}(size) + const Bindings = ImmutableDict{Int,Tuple{Int,Int}} const DEFAULT_BUFFER_SIZE = 1048576 -const BUFFER = Ref(CircularDeque{Bindings}(DEFAULT_BUFFER_SIZE)) +const BUFFER = Ref(createbuffer(Bindings, DEFAULT_BUFFER_SIZE)) const BUFFER_LOCK = ReentrantLock() -const MERGES_BUF = Ref(CircularDeque{Tuple{Int,Int}}(DEFAULT_BUFFER_SIZE)) +const MERGES_BUF = Ref(createbuffer(Tuple{Int,Int}, DEFAULT_BUFFER_SIZE)) const MERGES_BUF_LOCK = ReentrantLock() function resetbuffers!(bufsize) - BUFFER[] = CircularDeque{Bindings}(bufsize) - MERGES_BUF[] = CircularDeque{Tuple{Int,Int}}(bufsize) + BUFFER[] = createbuffer(Bindings, bufsize) + MERGES_BUF[] = createbuffer(Tuple{Int,Int}, bufsize) end function __init__() diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 5aed17dc..7434a913 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -2,7 +2,8 @@ module EMatchCompiler using TermInterface using ..Patterns -using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL +using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, + DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL, lockbuffer, lockmergesbuffer function ematcher(p::Any) function literal_ematcher(next, g, data, bindings) @@ -138,11 +139,11 @@ The format is as follows """ function ematcher_yield(p, npvars::Int, direction::Int) em = ematcher(p) - function ematcher_yield(g, rule_idx, id)::Int + function ematcher_yield(g, rule_idx, id, params)::Int n_matches = 0 em(g, (id,), EMPTY_ECLASS_DICT) do b,n - lock(BUFFER_LOCK) do - push!(BUFFER[], assoc(b, 0, (rule_idx * direction, id))) + lockbuffer(params) do + push!(params.buffer, assoc(b, 0, (rule_idx * direction, id))) n_matches+=1 end end @@ -154,8 +155,8 @@ ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1) function ematcher_yield_bidir(l, r, npvars::Int) eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1) - function ematcher_yield_bidir(g, rule_idx, id)::Int - eml(g,rule_idx,id) + emr(g,rule_idx,id) + function ematcher_yield_bidir(g, rule_idx, id, params)::Int + eml(g,rule_idx,id, params) + emr(g,rule_idx,id, params) end end @@ -163,4 +164,4 @@ ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p") export ematcher_yield, ematcher_yield_bidir -end \ No newline at end of file +end diff --git a/src/utils.jl b/src/utils.jl index 12cc9837..bf043404 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,24 @@ using Base: ImmutableDict +function lockbuffer(f, params) + if params.threaded + lock(params.buffer_lock) do + return f() + end + else + return f() + end +end +function lockmergesbuffer(f, params) + if params.threaded + lock(params.merges_buffer_lock) do + return f() + end + else + return f() + end +end + function binarize(e::T) where {T} !istree(e) && return e head = exprhead(e) From a675952e0c1b78bac5379ae6b47cda82af5c2984 Mon Sep 17 00:00:00 2001 From: Will Kimmerer Date: Sat, 21 Oct 2023 19:30:44 -0400 Subject: [PATCH 2/6] Two docstrings and a default parameter. --- src/EGraphs/saturation.jl | 2 ++ src/Metatheory.jl | 11 ++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index cd96b855..60c2d20b 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -72,8 +72,10 @@ Base.@kwdef mutable struct SaturationParams threaded::Bool = false timer::Bool = true printiter::Bool = false + "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." buffer::CircularDeque{Bindings} = BUFFER[] buffer_lock::ReentrantLock = BUFFER_LOCK + "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." merges_buffer::CircularDeque{Tuple{Int,Int}} = MERGES_BUF[] merges_buffer_lock::ReentrantLock = MERGES_BUF_LOCK end diff --git a/src/Metatheory.jl b/src/Metatheory.jl index fb8fe97e..3f3fca79 100644 --- a/src/Metatheory.jl +++ b/src/Metatheory.jl @@ -4,22 +4,23 @@ using DataStructures import Base.ImmutableDict -createbuffer(::Type{T}, size) where T = CircularDeque{T}(size) +createbuffer(::Type{T}, size = DEFAULT_BUFFER_SIZE) where T = + CircularDeque{T}(size) const Bindings = ImmutableDict{Int,Tuple{Int,Int}} const DEFAULT_BUFFER_SIZE = 1048576 -const BUFFER = Ref(createbuffer(Bindings, DEFAULT_BUFFER_SIZE)) +const BUFFER = Ref(createbuffer(Bindings)) const BUFFER_LOCK = ReentrantLock() -const MERGES_BUF = Ref(createbuffer(Tuple{Int,Int}, DEFAULT_BUFFER_SIZE)) +const MERGES_BUF = Ref(createbuffer(Tuple{Int,Int})) const MERGES_BUF_LOCK = ReentrantLock() -function resetbuffers!(bufsize) +function resetbuffers!(bufsize = DEFAULT_BUFFER_SIZE) BUFFER[] = createbuffer(Bindings, bufsize) MERGES_BUF[] = createbuffer(Tuple{Int,Int}, bufsize) end function __init__() - resetbuffers!(DEFAULT_BUFFER_SIZE) + resetbuffers!() end using Base.Meta From a3074666d6aa31ec92d70c04bfb42e2cf11f5e2c Mon Sep 17 00:00:00 2001 From: Will Kimmerer Date: Sat, 21 Oct 2023 19:37:40 -0400 Subject: [PATCH 3/6] rename and remove imports --- src/EGraphs/EGraphs.jl | 4 ++-- src/EGraphs/saturation.jl | 8 ++++---- src/ematch_compiler.jl | 4 ++-- src/utils.jl | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index 600f8a6d..8fcb1833 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -6,8 +6,8 @@ using DataStructures using TermInterface using TimerOutputs using Metatheory: - alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings, - lockmergesbuffer, lockbuffer + alwaystrue, cleanast, binarize, @log, Bindings, + lockmergesbuffer!, lockbuffer! using Metatheory.Patterns using Metatheory.Rules using Metatheory.EMatchCompiler diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 60c2d20b..4fe1c9cb 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -127,7 +127,7 @@ function eqsat_search!( )::Int n_matches = 0 - lockbuffer(params) do + lockbuffer!(params) do empty!(params.buffer) end @@ -231,7 +231,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation i = 0 @assert isempty(params.merges_buffer) - lockbuffer(params) do + lockbuffer!(params) do while !isempty(params.buffer) if reached(g, params.goal) @log "Goal reached" @@ -246,7 +246,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation rule = theory[rule_idx] - halt_reason = lockmergesbuffer(params) do + halt_reason = lockmergesbuffer!(params) do apply_rule!(bindings, g, rule, id, direction, params) end @@ -256,7 +256,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end end end - lockmergesbuffer(params) do + lockmergesbuffer!(params) do while !isempty(params.merges_buffer) (l, r) = popfirst!(params.merges_buffer) merge!(g, l, r) diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 7434a913..eb6e3d7f 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -3,7 +3,7 @@ module EMatchCompiler using TermInterface using ..Patterns using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, - DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, LL, lockbuffer, lockmergesbuffer + LL, lockbuffer!, lockmergesbuffer! function ematcher(p::Any) function literal_ematcher(next, g, data, bindings) @@ -142,7 +142,7 @@ function ematcher_yield(p, npvars::Int, direction::Int) function ematcher_yield(g, rule_idx, id, params)::Int n_matches = 0 em(g, (id,), EMPTY_ECLASS_DICT) do b,n - lockbuffer(params) do + lockbuffer!(params) do push!(params.buffer, assoc(b, 0, (rule_idx * direction, id))) n_matches+=1 end diff --git a/src/utils.jl b/src/utils.jl index bf043404..516adb93 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,6 @@ using Base: ImmutableDict -function lockbuffer(f, params) +function lockbuffer!(f, params) if params.threaded lock(params.buffer_lock) do return f() @@ -9,7 +9,7 @@ function lockbuffer(f, params) return f() end end -function lockmergesbuffer(f, params) +function lockmergesbuffer!(f, params) if params.threaded lock(params.merges_buffer_lock) do return f() From 8417ec33af570bd4d2ca3ef1411d1c18cf72de6f Mon Sep 17 00:00:00 2001 From: Will Kimmerer Date: Sat, 21 Oct 2023 19:42:11 -0400 Subject: [PATCH 4/6] add back imports where needed --- src/EGraphs/EGraphs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index 8fcb1833..cb30e509 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -7,7 +7,7 @@ using TermInterface using TimerOutputs using Metatheory: alwaystrue, cleanast, binarize, @log, Bindings, - lockmergesbuffer!, lockbuffer! + lockmergesbuffer!, lockbuffer!, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK using Metatheory.Patterns using Metatheory.Rules using Metatheory.EMatchCompiler From df26847a8931934c02c98fed5d0e61d903a7235d Mon Sep 17 00:00:00 2001 From: Will Kimmerer Date: Tue, 24 Oct 2023 14:00:10 -0400 Subject: [PATCH 5/6] move buffers into the egraph. --- src/EGraphs/egraph.jl | 21 ++++++++++++++--- src/EGraphs/saturation.jl | 49 +++++++++++++++++---------------------- src/ematch_compiler.jl | 10 ++++---- src/utils.jl | 12 +++++----- 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 981a81fe..29bc42bf 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -190,6 +190,14 @@ mutable struct EGraph termtypes::TermTypes numclasses::Int numnodes::Int + "If we use global buffers we may need to lock. Defaults to true." + needslock::Bool + "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." + buffer::CircularDeque{Bindings} + buffer_lock::ReentrantLock + "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." + merges_buffer::CircularDeque{Tuple{Int,Int}} + merges_buffer_lock::ReentrantLock end @@ -197,7 +205,13 @@ end EGraph(expr) Construct an EGraph from a starting symbolic expression `expr`. """ -function EGraph() +function EGraph(; + needslock::Bool = true, + buffer::CircularDeque{Bindings} = BUFFER[], + buffer_lock::ReentrantLock = BUFFER_LOCK, + merges_buffer::CircularDeque{Tuple{Int,Int}} = MERGES_BUF[], + merges_buffer_lock::ReentrantLock = MERGES_BUF_LOCK + ) EGraph( IntDisjointSet(), Dict{EClassId,EClass}(), @@ -211,11 +225,12 @@ function EGraph() 0, 0, # 0 + needslock, buffer, buffer_lock, merges_buffer, merges_buffer_lock ) end -function EGraph(e; keepmeta = false) - g = EGraph() +function EGraph(e; keepmeta = false, kwargs...) + g = EGraph(kwargs...) keepmeta && addanalysis!(g, :metadata_analysis) g.root = addexpr!(g, e; keepmeta = keepmeta) g diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 4fe1c9cb..6a052a4b 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -72,12 +72,6 @@ Base.@kwdef mutable struct SaturationParams threaded::Bool = false timer::Bool = true printiter::Bool = false - "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." - buffer::CircularDeque{Bindings} = BUFFER[] - buffer_lock::ReentrantLock = BUFFER_LOCK - "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." - merges_buffer::CircularDeque{Tuple{Int,Int}} = MERGES_BUF[] - merges_buffer_lock::ReentrantLock = MERGES_BUF_LOCK end # function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64} @@ -122,13 +116,12 @@ function eqsat_search!( g::EGraph, theory::Vector{<:AbstractRule}, scheduler::AbstractScheduler, - report::SaturationReport, - params::SaturationParams + report::SaturationReport )::Int n_matches = 0 - lockbuffer!(params) do - empty!(params.buffer) + lockbuffer!(g) do + empty!(g.buffer) end for (rule_idx, rule) in enumerate(theory) @@ -140,7 +133,7 @@ function eqsat_search!( ids = cached_ids(g, rule.left) rule isa BidirRule && (ids = ids ∪ cached_ids(g, rule.right)) for i in ids - n_matches += rule.ematcher!(g, rule_idx, i, params) + n_matches += rule.ematcher!(g, rule_idx, i) end inform!(scheduler, rule, n_matches) end @@ -170,24 +163,24 @@ function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId add!(g, ENodeTerm(eh, new_op, T, map(arg -> instantiate_enode!(bindings, g, arg), args))) end -function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction, params::SaturationParams) - push!(params.merges_buffer, (id, instantiate_enode!(buf, g, rule.right))) +function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) + push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right))) nothing end function apply_rule!( bindings::Bindings, g::EGraph, rule::EqualityRule, - id::EClassId, direction::Int, params::SaturationParams + id::EClassId, direction::Int ) pat_to_inst = direction == 1 ? rule.right : rule.left - push!(params.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) + push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) nothing end function apply_rule!( bindings::Bindings, g::EGraph, - rule::UnequalRule, id::EClassId, direction::Int, params::SaturationParams + rule::UnequalRule, id::EClassId, direction::Int ) pat_to_inst = direction == 1 ? rule.right : rule.left other_id = instantiate_enode!(bindings, g, pat_to_inst) @@ -215,13 +208,13 @@ end function apply_rule!( bindings::Bindings, g::EGraph, rule::DynamicRule, - id::EClassId, direction::Int, params::SaturationParams + id::EClassId, direction::Int ) f = rule.rhs_fun r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) isnothing(r) && return nothing rcid = addexpr!(g, r) - push!(params.merges_buffer, (id, rcid)) + push!(g.merges_buffer, (id, rcid)) return nothing end @@ -229,25 +222,25 @@ end function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::SaturationReport, params::SaturationParams) i = 0 - @assert isempty(params.merges_buffer) + @assert isempty(g.merges_buffer) - lockbuffer!(params) do - while !isempty(params.buffer) + lockbuffer!(g) do + while !isempty(g.buffer) if reached(g, params.goal) @log "Goal reached" rep.reason = :goalreached return end - bindings = popfirst!(params.buffer) + bindings = popfirst!(g.buffer) rule_idx, id = bindings[0] direction = sign(rule_idx) rule_idx = abs(rule_idx) rule = theory[rule_idx] - halt_reason = lockmergesbuffer!(params) do - apply_rule!(bindings, g, rule, id, direction, params) + halt_reason = lockmergesbuffer!(g) do + apply_rule!(bindings, g, rule, id, direction) end if !isnothing(halt_reason) @@ -256,9 +249,9 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end end end - lockmergesbuffer!(params) do - while !isempty(params.merges_buffer) - (l, r) = popfirst!(params.merges_buffer) + lockmergesbuffer!(g) do + while !isempty(g.merges_buffer) + (l, r) = popfirst!(g.merges_buffer) merge!(g, l, r) end end @@ -283,7 +276,7 @@ function eqsat_step!( setiter!(scheduler, curr_iter) - @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report, params) + @timeit report.to "Search" eqsat_search!(g, theory, scheduler, report) @timeit report.to "Apply" eqsat_apply!(g, theory, report, params) diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index eb6e3d7f..e890dfbd 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -139,11 +139,11 @@ The format is as follows """ function ematcher_yield(p, npvars::Int, direction::Int) em = ematcher(p) - function ematcher_yield(g, rule_idx, id, params)::Int + function ematcher_yield(g, rule_idx, id)::Int n_matches = 0 em(g, (id,), EMPTY_ECLASS_DICT) do b,n - lockbuffer!(params) do - push!(params.buffer, assoc(b, 0, (rule_idx * direction, id))) + lockbuffer!(g) do + push!(g.buffer, assoc(b, 0, (rule_idx * direction, id))) n_matches+=1 end end @@ -155,8 +155,8 @@ ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1) function ematcher_yield_bidir(l, r, npvars::Int) eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1) - function ematcher_yield_bidir(g, rule_idx, id, params)::Int - eml(g,rule_idx,id, params) + emr(g,rule_idx,id, params) + function ematcher_yield_bidir(g, rule_idx, id)::Int + eml(g,rule_idx,id) + emr(g,rule_idx,id) end end diff --git a/src/utils.jl b/src/utils.jl index 516adb93..89049d7a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,17 +1,17 @@ using Base: ImmutableDict -function lockbuffer!(f, params) - if params.threaded - lock(params.buffer_lock) do +function lockbuffer!(f, graph) + if graph.needslock + lock(graph.buffer_lock) do return f() end else return f() end end -function lockmergesbuffer!(f, params) - if params.threaded - lock(params.merges_buffer_lock) do +function lockmergesbuffer!(f, graph) + if graph.needslock + lock(graph.merges_buffer_lock) do return f() end else From 19394a421b3faef7e0b34ed34b4a98d6a6bac5f1 Mon Sep 17 00:00:00 2001 From: a Date: Sat, 28 Oct 2023 15:24:09 +0200 Subject: [PATCH 6/6] use local vector buffers --- src/EGraphs/EGraphs.jl | 3 +-- src/EGraphs/egraph.jl | 30 +++++++++++++++------------- src/EGraphs/saturation.jl | 31 ++++++++++------------------- src/Metatheory.jl | 22 +-------------------- src/ematch_compiler.jl | 41 +++++++++++++++++++-------------------- src/utils.jl | 23 ++-------------------- 6 files changed, 51 insertions(+), 99 deletions(-) diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index 1468945f..1a1bdc6a 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -5,8 +5,7 @@ include("../docstrings.jl") using DataStructures using TermInterface using TimerOutputs -using Metatheory: - alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings +using Metatheory: alwaystrue, cleanast, binarize using Metatheory.Patterns using Metatheory.Rules using Metatheory.EMatchCompiler diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 165e015e..c989b0da 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -4,9 +4,14 @@ abstract type AbstractENode end +import Metatheory: maybelock! + const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple} const EClassId = Int64 const TermTypes = Dict{Tuple{Any,Int},Type} +# TODO document bindings +const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}} +const DEFAULT_BUFFER_SIZE = 1048576 struct ENodeLiteral <: AbstractENode value @@ -193,11 +198,10 @@ mutable struct EGraph "If we use global buffers we may need to lock. Defaults to true." needslock::Bool "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." - buffer::CircularDeque{Bindings} - buffer_lock::ReentrantLock + buffer::Vector{Bindings} "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." - merges_buffer::CircularDeque{Tuple{Int,Int}} - merges_buffer_lock::ReentrantLock + merges_buffer::Vector{Tuple{Int,Int}} + lock::ReentrantLock end @@ -205,13 +209,7 @@ end EGraph(expr) Construct an EGraph from a starting symbolic expression `expr`. """ -function EGraph(; - needslock::Bool = true, - buffer::CircularDeque{Bindings} = BUFFER[], - buffer_lock::ReentrantLock = BUFFER_LOCK, - merges_buffer::CircularDeque{Tuple{Int,Int}} = MERGES_BUF[], - merges_buffer_lock::ReentrantLock = MERGES_BUF_LOCK - ) +function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE) EGraph( IntDisjointSet(), Dict{EClassId,EClass}(), @@ -224,11 +222,17 @@ function EGraph(; TermTypes(), 0, 0, - # 0 - needslock, buffer, buffer_lock, merges_buffer, merges_buffer_lock + needslock, + Bindings[], + Tuple{Int,Int}[], + ReentrantLock(), ) end +function maybelock!(f::Function, g::EGraph) + g.needslock ? lock(f, g.buffer_lock) : f() +end + function EGraph(e; keepmeta = false, kwargs...) g = EGraph(kwargs...) keepmeta && addanalysis!(g, :metadata_analysis) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index bb376ed6..da7dc906 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -115,11 +115,11 @@ function eqsat_search!( g::EGraph, theory::Vector{<:AbstractRule}, scheduler::AbstractScheduler, - report::SaturationReport + report::SaturationReport, )::Int n_matches = 0 - lockbuffer!(g) do + maybelock!(g) do empty!(g.buffer) end @@ -170,20 +170,14 @@ function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) nothing end -function apply_rule!( - bindings::Bindings, g::EGraph, rule::EqualityRule, - id::EClassId, direction::Int -) +function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int) pat_to_inst = direction == 1 ? rule.right : rule.left push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) nothing end -function apply_rule!( - bindings::Bindings, g::EGraph, - rule::UnequalRule, id::EClassId, direction::Int -) +function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClassId, direction::Int) pat_to_inst = direction == 1 ? rule.right : rule.left other_id = instantiate_enode!(bindings, g, pat_to_inst) @@ -208,10 +202,7 @@ function instantiate_actual_param!(bindings::Bindings, g::EGraph, i) return eclass end -function apply_rule!( - bindings::Bindings, g::EGraph, rule::DynamicRule, - id::EClassId, direction::Int -) +function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClassId, direction::Int) f = rule.rhs_fun r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) isnothing(r) && return nothing @@ -227,7 +218,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation @assert isempty(g.merges_buffer) @debug "APPLYING $(length(g.buffer)) matches" - lockbuffer!(g) do + maybelock!(g) do while !isempty(g.buffer) if reached(g, params.goal) @@ -236,16 +227,14 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation return end - bindings = popfirst!(g.buffer) + bindings = pop!(g.buffer) rule_idx, id = bindings[0] direction = sign(rule_idx) rule_idx = abs(rule_idx) rule = theory[rule_idx] - halt_reason = lockmergesbuffer!(g) do - apply_rule!(bindings, g, rule, id, direction) - end + halt_reason = apply_rule!(bindings, g, rule, id, direction) if !isnothing(halt_reason) rep.reason = halt_reason @@ -253,9 +242,9 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end end end - lockmergesbuffer!(g) do + maybelock!(g) do while !isempty(g.merges_buffer) - (l, r) = popfirst!(g.merges_buffer) + (l, r) = pop!(g.merges_buffer) merge!(g, l, r) end end diff --git a/src/Metatheory.jl b/src/Metatheory.jl index c7806174..6ab2a811 100644 --- a/src/Metatheory.jl +++ b/src/Metatheory.jl @@ -2,27 +2,6 @@ module Metatheory using DataStructures -import Base.ImmutableDict - -createbuffer(::Type{T}, size = DEFAULT_BUFFER_SIZE) where T = - CircularDeque{T}(size) - -const Bindings = ImmutableDict{Int,Tuple{Int,Int}} -const DEFAULT_BUFFER_SIZE = 1048576 -const BUFFER = Ref(createbuffer(Bindings)) -const BUFFER_LOCK = ReentrantLock() -const MERGES_BUF = Ref(createbuffer(Tuple{Int,Int})) -const MERGES_BUF_LOCK = ReentrantLock() - -function resetbuffers!(bufsize = DEFAULT_BUFFER_SIZE) - BUFFER[] = createbuffer(Bindings, bufsize) - MERGES_BUF[] = createbuffer(Tuple{Int,Int}, bufsize) -end - -function __init__() - resetbuffers!() -end - using Base.Meta using Reexport using TermInterface @@ -30,6 +9,7 @@ using TermInterface @inline alwaystrue(x) = true function lookup_pat end +function maybelock! end include("docstrings.jl") include("utils.jl") diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index e890dfbd..ea092dd3 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -2,8 +2,7 @@ module EMatchCompiler using TermInterface using ..Patterns -using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, - LL, lockbuffer!, lockmergesbuffer! +using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, LL, maybelock! function ematcher(p::Any) function literal_ematcher(next, g, data, bindings) @@ -49,7 +48,7 @@ function predicate_ematcher(p::PatVar, pred) end end end - + function ematcher(p::PatVar) pred_matcher = predicate_ematcher(p, p.predicate) @@ -116,14 +115,14 @@ function ematcher(p::PatTerm) for n in g[car(data)] if canbindtop(n) - loop(LL(arguments(n),1), bindings, ematchers) + loop(LL(arguments(n), 1), bindings, ematchers) end end end -end +end -const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int, Int}}() +const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}() """ Substitutions are efficiently represented in memory as vector of tuples of two integers. @@ -138,26 +137,26 @@ The format is as follows * The end of a substitution is delimited by (0,0) """ function ematcher_yield(p, npvars::Int, direction::Int) - em = ematcher(p) - function ematcher_yield(g, rule_idx, id)::Int - n_matches = 0 - em(g, (id,), EMPTY_ECLASS_DICT) do b,n - lockbuffer!(g) do - push!(g.buffer, assoc(b, 0, (rule_idx * direction, id))) - n_matches+=1 - end - end - n_matches + em = ematcher(p) + function ematcher_yield(g, rule_idx, id)::Int + n_matches = 0 + em(g, (id,), EMPTY_ECLASS_DICT) do b, n + maybelock!(g) do + push!(g.buffer, assoc(b, 0, (rule_idx * direction, id))) + n_matches += 1 + end end + n_matches + end end -ematcher_yield(p,npvars) = ematcher_yield(p,npvars,1) +ematcher_yield(p, npvars) = ematcher_yield(p, npvars, 1) function ematcher_yield_bidir(l, r, npvars::Int) - eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1) - function ematcher_yield_bidir(g, rule_idx, id)::Int - eml(g,rule_idx,id) + emr(g,rule_idx,id) - end + eml, emr = ematcher_yield(l, npvars, 1), ematcher_yield(r, npvars, -1) + function ematcher_yield_bidir(g, rule_idx, id)::Int + eml(g, rule_idx, id) + emr(g, rule_idx, id) + end end ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p") diff --git a/src/utils.jl b/src/utils.jl index 89049d7a..8e627165 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,24 +1,5 @@ using Base: ImmutableDict -function lockbuffer!(f, graph) - if graph.needslock - lock(graph.buffer_lock) do - return f() - end - else - return f() - end -end -function lockmergesbuffer!(f, graph) - if graph.needslock - lock(graph.merges_buffer_lock) do - return f() - end - else - return f() - end -end - function binarize(e::T) where {T} !istree(e) && return e head = exprhead(e) @@ -170,8 +151,8 @@ end macro matchable(expr) @assert expr.head == :struct name = expr.args[2] - if name isa Expr - name.head === :(<:) && (name = name.args[1]) + if name isa Expr + name.head === :(<:) && (name = name.args[1]) name isa Expr && name.head === :curly && (name = name.args[1]) end fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args)