Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change implementation of BackoffScheduler to match egg. #249

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ simpl1_math = :(a + b + (0 * c) + d)
SUITE["basic_maths"]["simpl1"] = @benchmarkable (@assert :(a + b + d) == simplify(
$simpl1_math,
$maths_theory,
$(SaturationParams(; timer = false)),
$(SaturationParams(enodelimit=15000, timeout=8, timer = false)),
postprocess_maths,
))

simpl2_math = :(0 + (1 * foo) * 0 + (a * 0) + a)
SUITE["basic_maths"]["simpl2"] =
@benchmarkable (@assert :a == simplify($simpl2_math, $maths_theory, $(SaturationParams()), postprocess_maths))
@benchmarkable (@assert :a == simplify($simpl2_math, $maths_theory, $(SaturationParams(enodelimit=15000, timeout=8, timer=false)), postprocess_maths))


# ==================================================================
Expand All @@ -60,7 +60,7 @@ ex_orig = :(((p ⟹ q) && (r ⟹ s) && (p || r)) ⟹ (q || s))
ex_logic = rewrite(ex_orig, impl)

SUITE["prop_logic"]["rewrite"] = @benchmarkable rewrite($ex_orig, $impl)
SUITE["prop_logic"]["prove1"] = @benchmarkable (@assert prove($propositional_logic_theory, $ex_logic, 3, 6))
SUITE["prop_logic"]["prove1"] = @benchmarkable (@assert prove($propositional_logic_theory, $ex_logic, 2, 6))

ex_demorgan = :(!(p || q) == (!p && !q))
SUITE["prop_logic"]["demorgan"] = @benchmarkable (@assert prove($propositional_logic_theory, $ex_demorgan))
Expand Down
5 changes: 1 addition & 4 deletions examples/prove.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Sketch function for basic iterative saturation and extraction
# Sketch function for basic iterative saturation and extraction
function prove(
t,
ex,
Expand All @@ -19,9 +19,6 @@ function prove(
params.goal = (g::EGraph) -> in_same_class(g, ids...)
saturate!(g, t, params)
ex = extract!(g, astsize)
if !TermInterface.isexpr(ex)
return ex
end
end
return ex
end
Expand Down
197 changes: 143 additions & 54 deletions src/EGraphs/Schedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@ module Schedulers

include("../docstrings.jl")

using Metatheory
using Metatheory.Rules
using Metatheory.EGraphs
using Metatheory.Patterns
using DocStringExtensions

import Metatheory: UNDEF_ID_VEC
import Metatheory.EGraphs: IdKey

export AbstractScheduler,
SimpleScheduler, BackoffScheduler, FreezingScheduler, ScoredScheduler, cansaturate, cansearch, inform!, setiter!
SimpleScheduler, BackoffScheduler, FreezingScheduler, ScoredScheduler, search_matches!, cansaturate, setiter!

"""
Represents a rule scheduler for the equality saturation process

"""
abstract type AbstractScheduler end


"""
cansaturate(s::AbstractScheduler)

Expand All @@ -24,26 +29,12 @@ Should return `true` if the e-graph can be said to be saturated
function cansaturate end

"""
cansearch(s::AbstractScheduler, i::Int)
cansearch(s::AbstractScheduler, i::Int, eclass_id::Id)

Given a theory `t` and a rule `r` with index `i` in the theory,
should return `false` if the search for rule with index `i` should be skipped
for the current iteration. An extra `eclass_id::Id` arguments can be passed
in order to filter out specific e-classes.
"""
function cansearch end
search_matches!(s::AbstractScheduler, ematch_buffer::OptBuffer{UInt128}, rule_idx::Int)

Uses the scheduler `s` to search for matches for rule with index `rule_idx`.
Matches are stored in the ematch_buffer. Returns the number of matches.
"""
inform!(s::AbstractScheduler, i::Int, n_matches)
inform!(s::AbstractScheduler, i::Int, eclass_id::Id, n_matches)


Given a theory `t` and a rule `r` with index `i` in the theory,
This function is called **after** pattern matching (searching) the e-graph,
it informs the scheduler about the number of yielded matches.
"""
function inform! end
function search_matches! end

"""
setiter!(s::AbstractScheduler, i::Int)
Expand All @@ -55,7 +46,7 @@ function setiter! end
"""
rebuild!(s::AbstractScheduler, g::EGraph)

Some schedulers may hold data that need to be re-canonicalized
Some schedulers may hold data that need to be re-canonicalized
after an iteration of equality saturation, such as references to e-class IDs.
This is called by equality saturation after e-graph `rebuild!`
"""
Expand All @@ -65,13 +56,33 @@ function rebuild! end
# Defaults
# ===========================================================================

@inline inform!(::AbstractScheduler, ::Int, ::Int) = nothing
@inline inform!(::AbstractScheduler, ::Int, ::Id, ::Int) = nothing
@inline search_matches!(::AbstractScheduler, ::OptBuffer{UInt128}, ::Int) = 0
@inline cansaturate(::AbstractScheduler) = true
@inline setiter!(::AbstractScheduler, ::Int) = nothing
@inline rebuild!(::AbstractScheduler) = nothing




function cached_ids(g::EGraph, p::PatExpr)::Vector{Id}
if isground(p)
id = lookup_pat(g, p)
iszero(id) ? UNDEF_ID_VEC : [id]
else
get(g.classes_by_op, IdKey(v_signature(p.n)), UNDEF_ID_VEC)
end
end

function cached_ids(g::EGraph, p::PatLiteral)
id = lookup_pat(g, p)
id > 0 && return [id]
return UNDEF_ID_VEC
end

cached_ids(g::EGraph, ::PatVar) = Iterators.map(x -> x.val, keys(g.classes))



# ===========================================================================
# SimpleScheduler
# ===========================================================================
Expand All @@ -80,13 +91,31 @@ function rebuild! end
"""
A simple Rewrite Scheduler that applies every rule every time
"""
struct SimpleScheduler <: AbstractScheduler end

SimpleScheduler(::EGraph, ::Theory) = SimpleScheduler()
struct SimpleScheduler <: AbstractScheduler
g::EGraph
theory::Theory
end

@inline cansaturate(s::SimpleScheduler) = true
@inline cansearch(s::SimpleScheduler, ::Int) = true
@inline cansearch(s::SimpleScheduler, ::Int, ::Id) = true

"""
Apply all rules to all eclasses.
"""
function search_matches!(s::SimpleScheduler,
ematch_buffer::OptBuffer{UInt128},
rule_idx::Int)
n_matches = 0
rule = s.theory[rule_idx]
for i in cached_ids(s.g, rule.left)
n_matches += rule.ematcher_left!(s.g, rule_idx, i, rule.stack, ematch_buffer)
end
if is_bidirectional(rule)
for i in cached_ids(s.g, rule.right)
n_matches += rule.ematcher_right!(s.g, rule_idx, i, rule.stack, ematch_buffer)
end
end
n_matches
end

# ===========================================================================
# BackoffScheduler
Expand All @@ -103,42 +132,73 @@ This seems effective at preventing explosive rules like
associativity from taking an unfair amount of resources.
"""
Base.@kwdef mutable struct BackoffScheduler <: AbstractScheduler
data::Vector{Tuple{Int,Int}} # TimesBanned ⊗ BannedUntil
g::EGraph
theory::Theory
const data::Vector{Tuple{Int,Int}} # TimesBanned ⊗ BannedUntil
const g::EGraph
const theory::Theory
const match_limit::Int = 1000
const ban_length::Int = 5
curr_iter::Int = 1
match_limit::Int = 1000
ban_length::Int = 5
end

@inline cansearch(s::BackoffScheduler, rule_idx::Int)::Bool = s.curr_iter > last(s.data[rule_idx])
@inline cansearch(s::BackoffScheduler, rule_idx::Int, eclass_id::Id) = true

BackoffScheduler(g::EGraph, theory::Theory; kwargs...) =
BackoffScheduler(; data = fill((0, 0), length(theory)), g, theory, kwargs...)

# can saturate if there's no banned rule
cansaturate(s::BackoffScheduler)::Bool = all((<)(s.curr_iter) ∘ last, s.data)

function setiter!(s::BackoffScheduler, curr_iter::Int)
s.curr_iter = curr_iter
end


function search_matches!(s::BackoffScheduler,
ematch_buffer::OptBuffer{UInt128},
rule_idx::Int)

(times_banned, banned_until) = s.data[rule_idx]
rule = s.theory[rule_idx]

if s.curr_iter < banned_until
@debug "Skipping $rule (banned $times_banned x) until $banned_until."
return 0
end

function inform!(s::BackoffScheduler, rule_idx::Int, n_matches::Int)
(times_banned, _) = s.data[rule_idx]
threshold = s.match_limit << times_banned
n_matches = 0
old_ematch_buffer_size = length(ematch_buffer)
# Search matches in the egraph with the theshold (+1) as a limit.
# Stop early when we found more matches than the threshold
for i in cached_ids(s.g, rule.left)
eclass_matches = rule.ematcher_left!(s.g, rule_idx, i, rule.stack, ematch_buffer, threshold + 1 - n_matches)
n_matches += eclass_matches
n_matches <= threshold || break
end
if is_bidirectional(rule) && n_matches <= threshold
for i in cached_ids(s.g, rule.right)
eclass_matches = rule.ematcher_right!(s.g, rule_idx, i, rule.stack, ematch_buffer, threshold + 1 - n_matches)
n_matches += eclass_matches
n_matches <= threshold || break
end
end

if n_matches > threshold
s.data[rule_idx] = (times_banned += 1, s.curr_iter + (s.ban_length << times_banned))
ban_length = s.ban_length << times_banned
banned_until = s.curr_iter + ban_length
@debug "Banning $rule (banned $times_banned times) for $ban_length iterations (threshold: $threshold < $n_matches matches)."
s.data[rule_idx] = (times_banned + 1, banned_until)
# revert matches because the rule could be matched to eclasses only partially
resize!(ematch_buffer, old_ematch_buffer_size)
return 0
end
end

function setiter!(s::BackoffScheduler, curr_iter::Int)
s.curr_iter = curr_iter
n_matches
end


# ===========================================================================
# FreezingScheduler
# ===========================================================================

struct FreezingSchedulerStat
mutable struct FreezingSchedulerStat
times_banned::Int
banned_until::Int
size_limit::Int
Expand All @@ -147,20 +207,17 @@ end

Base.@kwdef mutable struct FreezingScheduler <: AbstractScheduler
data::Dict{Id,FreezingSchedulerStat} = Dict{Id,FreezingSchedulerStat}()
g::EGraph
theory::Theory
const g::EGraph
const theory::Theory
const default_eclass_size_limit::Int = 10
const default_eclass_size_increment::Int = 3
const default_eclass_ban_length::Int = 3
const default_eclass_ban_increment::Int = 2
curr_iter::Int = 1
default_eclass_size_limit::Int = 10
default_eclass_size_increment::Int = 3
default_eclass_ban_length::Int = 3
default_eclass_ban_increment::Int = 2
end

FreezingScheduler(g::EGraph, theory::Theory; kwargs...) = FreezingScheduler(; g, theory, kwargs...)

@inline cansearch(s::FreezingScheduler, rule_idx::Int)::Bool = true
@inline cansearch(s::FreezingScheduler, ::Int, eclass_id::Id) = s.curr_iter > s[eclass_id].banned_until

function Base.getindex(s::FreezingScheduler, id::Id)
haskey(s.data, id) && return s.data[id]
nid = find(s.g, id)
Expand All @@ -172,18 +229,50 @@ end
# can saturate if there's no banned rule
cansaturate(s::FreezingScheduler)::Bool = all(stat -> stat.banned_until < s.curr_iter, values(s.data))

function inform!(s::FreezingScheduler, rule_idx::Int, n_matches::Int, eclass_id::Id)
function cansearch!(s::FreezingScheduler, eclass_id)
stats = s[eclass_id]
if s.curr_iter < stats.banned_until
@debug "Skipping eclass $eclass_id (banned $(stats.times_banned) times) until $(stats.banned_until)."
return false
end

threshold = stats.size_limit + s.default_eclass_size_increment * stats.times_banned
len = length(s.g[eclass_id])

if len > threshold
ban_length = stats.ban_length + s.default_eclass_ban_increment * stats.times_banned
stats.times_banned += 1
stats.banned_until = s.curr_iter + ban_length
@debug "Banning eclass $eclass_id (banned $(stats.times_banned) times) for $ban_length iterations (threshold: $threshold < $len nodes))."

return false
end

true
end

function search_matches!(s::FreezingScheduler,
ematch_buffer::OptBuffer{UInt128},
rule_idx::Int)
n_matches = 0
rule = s.theory[rule_idx]
for i in cached_ids(s.g, rule.left)
if cansearch!(s, i)
n_matches += rule.ematcher_left!(s.g, rule_idx, i, rule.stack, ematch_buffer)
end
end

# repeat for RHS if bidirectional
if is_bidirectional(rule)
for i in cached_ids(s.g, rule.right)
if cansearch!(s, i)
n_matches += rule.ematcher_right!(s.g, rule_idx, i, rule.stack, ematch_buffer)
end
end
end
n_matches
end


function setiter!(s::FreezingScheduler, curr_iter::Int)
s.curr_iter = curr_iter
end
Expand Down
2 changes: 0 additions & 2 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# ==============================================================
# Interface to implement for custom analyses
# ==============================================================

"""
modify!(eclass::EClass{Analysis})

Expand Down Expand Up @@ -464,7 +463,6 @@ function rebuild_classes!(g::EGraph)
end

for (eclass_id, eclass) in g.classes
# old_len = length(eclass.nodes)
for n in eclass.nodes
canonicalize!(g, n)
end
Expand Down
Loading
Loading