Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ale/3.0 new compact rule #206

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/rewrite.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ The `@rule` macro takes a pair of patterns -- the _matcher_ and the _consequent

**Rule operators**:
- `LHS => RHS`: create a `DynamicRule`. The RHS is *evaluated* on rewrite.
- `LHS --> RHS`: create a `RewriteRule`. The RHS is **not** evaluated but *symbolically substituted* on rewrite.
- `LHS == RHS`: create a `EqualityRule`. In e-graph rewriting, this rule behaves like `RewriteRule` but can go in both directions. Doesn't work in classical rewriting.
- `LHS --> RHS`: create a `DirectedRule`. The RHS is **not** evaluated but *symbolically substituted* on rewrite.
- `LHS == RHS`: create a `EqualityRule`. In e-graph rewriting, this rule behaves like `DirectedRule` but can go in both directions. Doesn't work in classical rewriting.
- `LHS ≠ RHS`: create a `UnequalRule`. Can only be used in e-graphs, and is used to eagerly stop the process of rewriting if LHS is found to be equal to RHS.


Expand Down
10 changes: 5 additions & 5 deletions src/EGraphs/Schedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ struct SimpleScheduler <: AbstractScheduler end

cansaturate(s::SimpleScheduler) = true
cansearch(s::SimpleScheduler, r::Int) = true
function SimpleScheduler(G::EGraph, theory::Vector{<:AbstractRule})
function SimpleScheduler(::EGraph, ::Vector{RewriteRule})
SimpleScheduler()
end
inform!(s::SimpleScheduler, r, n_matches) = nothing
setiter!(s::SimpleScheduler, iteration) = nothing
inform!(::SimpleScheduler, r, n_matches) = nothing
setiter!(::SimpleScheduler, iteration) = nothing


# ===========================================================================
Expand All @@ -94,7 +94,7 @@ associativity from taking an unfair amount of resources.
mutable struct BackoffScheduler <: AbstractScheduler
data::Vector{Tuple{Int,Int}} # TimesBanned ⊗ BannedUntil
G::EGraph
theory::Vector{<:AbstractRule}
theory::Vector{RewriteRule}
curr_iter::Int
match_limit::Int
ban_length::Int
Expand All @@ -103,7 +103,7 @@ end
cansearch(s::BackoffScheduler, rule_idx::Int)::Bool = s.curr_iter > last(s.data[rule_idx])


function BackoffScheduler(G::EGraph, theory::Vector{<:AbstractRule}, match_limit::Int = 1000, ban_length::Int = 5)
function BackoffScheduler(G::EGraph, theory::Vector{RewriteRule}, match_limit::Int = 1000, ban_length::Int = 5)
BackoffScheduler(fill((0, 0), length(theory)), G, theory, 1, match_limit, ban_length)
end

Expand Down
34 changes: 16 additions & 18 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Returns an iterator of `Match`es.
"""
function eqsat_search!(
g::EGraph,
theory::Vector{<:AbstractRule},
theory::Vector{RewriteRule},
scheduler::AbstractScheduler,
report::SaturationReport,
ematch_buffer::OptBuffer{UInt128},
Expand All @@ -83,20 +83,19 @@ function eqsat_search!(
continue
end
ids_left = cached_ids(g, rule.left)
ids_right = rule isa BidirRule ? cached_ids(g, rule.right) : UNDEF_ID_VEC
ids_right = is_bidirectional(rule) ? cached_ids(g, rule.right) : UNDEF_ID_VEC

if rule isa BidirRule
for i in ids_left
n_matches += rule.ematcher_new_left!(g, rule_idx, i, rule.stack, ematch_buffer)
end

for i in ids_left
n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer)
end

if is_bidirectional(rule)
for i in ids_right
n_matches += rule.ematcher_new_right!(g, rule_idx, i, rule.stack, ematch_buffer)
end
else
for i in ids_left
n_matches += rule.ematcher!(g, rule_idx, i, rule.stack, ematch_buffer)
n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer)
end
end

n_matches - prev_matches > 0 && @debug "Rule $rule_idx: $rule produced $(n_matches - prev_matches) matches"
# if n_matches - prev_matches > 2 && rule_idx == 2
# @debug buffer_readable(g, old_len)
Expand Down Expand Up @@ -134,7 +133,7 @@ function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id
add!(g, p.n, true)
end

function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction, merges_buffer::OptBuffer{UInt128})
function apply_rule!(buf, g::EGraph, rule::DirectedRule, id, direction, merges_buffer::OptBuffer{UInt128})
new_id::Id = instantiate_enode!(buf, g, rule.right)
push!(merges_buffer, v_pair(new_id, id))
nothing
Expand Down Expand Up @@ -175,8 +174,7 @@ function instantiate_actual_param!(bindings, g::EGraph, i)
end

function apply_rule!(bindings, g::EGraph, rule::DynamicRule, id::Id, direction::Int, merges_buffer::OptBuffer{UInt128})
f = rule.rhs_fun
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
r = rule.right(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
isnothing(r) && return nothing
rcid = addexpr!(g, r)
push!(merges_buffer, v_pair(rcid, id))
Expand All @@ -187,7 +185,7 @@ const CHECK_GOAL_EVERY_N_MATCHES = 20

function eqsat_apply!(
g::EGraph,
theory::Vector{<:AbstractRule},
theory::Vector{RewriteRule},
rep::SaturationReport,
params::SaturationParams,
ematch_buffer::OptBuffer{UInt128},
Expand Down Expand Up @@ -273,7 +271,7 @@ Core algorithm of the library: the equality saturation step.
"""
function eqsat_step!(
g::EGraph,
theory::Vector{<:AbstractRule},
theory::Vector{RewriteRule},
curr_iter,
scheduler::AbstractScheduler,
params::SaturationParams,
Expand Down Expand Up @@ -302,7 +300,7 @@ end
Given an [`EGraph`](@ref) and a collection of rewrite rules,
execute the equality saturation algorithm.
"""
function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = SaturationParams())
function saturate!(g::EGraph, theory::Vector{RewriteRule}, params = SaturationParams())
curr_iter = 0

sched = params.scheduler(g, theory, params.schedulerparams...)
Expand Down Expand Up @@ -366,7 +364,7 @@ function areequal(theory::Vector, exprs...; params = SaturationParams())
areequal(g, theory, exprs...; params)
end

function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams())
function areequal(g::EGraph, t::Vector{RewriteRule}, exprs...; params = SaturationParams())
n = length(exprs)
n == 1 && return true

Expand Down
17 changes: 12 additions & 5 deletions src/Library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,26 @@ end

macro associativity(op)
esc(quote
[(@left_associative $op), (@right_associative $op)]
RewriteRule[(@left_associative $op), (@right_associative $op)]
end)
end

macro monoid(op, id)
esc(quote
[(@left_associative($op)), (@right_associative($op)), (@identity_left($op, $id)), (@identity_right($op, $id))]
end)
esc(
quote
RewriteRule[
(@left_associative($op)),
(@right_associative($op)),
(@identity_left($op, $id)),
(@identity_right($op, $id)),
]
end,
)
end

macro commutative_monoid(op, id)
esc(quote
[(@commutativity $op), (@left_associative $op), (@right_associative $op), (@identity_left $op $id)]
RewriteRule[(@commutativity $op), (@left_associative $op), (@right_associative $op), (@identity_left $op $id)]
end)
end

Expand Down
Loading