From 329595ef0cce30ac275ef10a167ec67d31f8f3c3 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 10 Jan 2024 18:08:08 +0100 Subject: [PATCH] use UInt as id --- src/EGraphs/egraph.jl | 44 +++++++++++++++++++-------------------- src/EGraphs/saturation.jl | 18 +++++++++------- src/EGraphs/unionfind.jl | 10 ++++----- src/Patterns.jl | 32 +++++++++++++++------------- src/TermInterface.jl | 9 ++++++-- src/ematch_compiler.jl | 28 ++++++++++++------------- test/egraphs/egraphs.jl | 20 +++++++++--------- test/egraphs/unionfind.jl | 18 +++++++--------- 8 files changed, 94 insertions(+), 85 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 04086606..71e4461c 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -29,11 +29,10 @@ Given an ENode `n`, `make` should return the corresponding analysis value. """ function make end -const EClassId = Int64 -const TermTypes = Dict{Tuple{Any,Int},Type} +const EClassId = UInt64 # TODO document bindings -const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}} -const UNDEF_ARGS = Vector{EClassId}(undef, 0) +const Bindings = Base.ImmutableDict{Int,Tuple{EClassId,Int}} +const UNDEF_ID_VEC = Vector{EClassId}(undef, 0) # @compactify begin struct ENode @@ -44,7 +43,7 @@ struct ENode args::Vector{EClassId} hash::Ref{UInt} ENode(head, operation, args) = new(true, head, operation, args, Ref{UInt}(0)) - ENode(literal) = new(false, nothing, literal, UNDEF_ARGS, Ref{UInt}(0)) + ENode(literal) = new(false, nothing, literal, UNDEF_ID_VEC, Ref{UInt}(0)) end TermInterface.istree(n::ENode) = n.istree @@ -52,7 +51,7 @@ TermInterface.head(n::ENode) = n.head TermInterface.operation(n::ENode) = n.operation TermInterface.arguments(n::ENode) = n.args TermInterface.children(n::ENode) = [n.operation; n.args...] -TermInterface.arity(n::ENode) = length(n.args) +TermInterface.arity(n::ENode)::Int = length(n.args) # This optimization comes from SymbolicUtils @@ -78,7 +77,7 @@ end Base.show(io::IO, x::ENode) = print(io, to_expr(x)) -function op_key(n) +function op_key(n)::Pair{Any,Int} op = operation(n) (op isa Union{Function,DataType} ? nameof(op) : op) => (istree(n) ? arity(n) : -1) end @@ -155,7 +154,7 @@ mutable struct EGraph{Head,Analysis} "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." buffer::Vector{Bindings} "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." - merges_buffer::Vector{Tuple{Int,Int}} + merges_buffer::Vector{EClassId} lock::ReentrantLock end @@ -167,16 +166,16 @@ Construct an EGraph from a starting symbolic expression `expr`. function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis} EGraph{Head,Analysis}( UnionFind(), - Dict{EClassId,EClass}(), + Dict{EClassId,EClass{Analysis}}(), Dict{ENode,EClassId}(), Pair{ENode,EClassId}[], UniqueQueue{Pair{ENode,EClassId}}(), - -1, + 0, Dict{Pair{Any,Int},Vector{EClassId}}(), false, needslock, Bindings[], - Tuple{Int,Int}[], + EClassId[], ReentrantLock(), ) end @@ -232,7 +231,7 @@ end function lookup(g::EGraph, n::ENode)::EClassId cc = canonicalize(g, n) - haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1 + haskey(g.memo, cc) ? find(g, g.memo[cc]) : 0 end @@ -288,26 +287,22 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int [`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly insert the literal into the [`EGraph`](@ref). """ -function addexpr!(g::EGraph, se, keepmeta = false)::EClassId +function addexpr!(g::EGraph, se)::EClassId se isa EClass && return se.id e = preprocess(se) n = if istree(se) args = arguments(e) - ar = length(args) + ar = arity(e) class_ids = Vector{EClassId}(undef, ar) for i in 1:ar - @inbounds class_ids[i] = addexpr!(g, args[i], keepmeta) + @inbounds class_ids[i] = addexpr!(g, args[i]) end ENode(head(e), operation(e), class_ids) else # constant enode ENode(e) end id = add!(g, n) - if keepmeta - meta = TermInterface.metadata(e) - !isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta) - end return id end @@ -512,15 +507,18 @@ function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head} eh = Head(head_symbol(head(p))) - ids = map(x -> lookup_pat(g, x), args) - !all((>)(0), ids) && return -1 + ids = Vector{EClassId}(undef, ar) + for i in 1:ar + @inbounds ids[i] = lookup_pat(g, args[i]) + ids[i] <= 0 && return 0 + end if Head == ExprHead && op isa Union{Function,DataType} id = lookup(g, ENode(eh, op, ids)) - id < 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id + id <= 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id else lookup(g, ENode(eh, op, ids)) end end -lookup_pat(g::EGraph, p::Any) = lookup(g, ENode(p)) +lookup_pat(g::EGraph, p::Any)::EClassId = lookup(g, ENode(p)) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 734e2d95..6586dbcf 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -38,12 +38,12 @@ Base.@kwdef mutable struct SaturationParams timer::Bool = true end -function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64} +function cached_ids(g::EGraph, p::PatTerm)::Vector{EClassId} if isground(p) id = lookup_pat(g, p) !isnothing(id) && return [id] else - get(g.classes_by_op, op_key(p), ()) + get(g.classes_by_op, op_key(p), UNDEF_ID_VEC) end end @@ -115,13 +115,15 @@ function instantiate_enode!(bindings::Bindings, g::EGraph{Head}, p::PatTerm)::EC end function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) - push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right))) + push!(g.merges_buffer, id) + push!(g.merges_buffer, instantiate_enode!(buf, g, rule.right)) nothing end function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int) pat_to_inst = direction == 1 ? rule.right : rule.left - push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) + push!(g.merges_buffer, id) + push!(g.merges_buffer, instantiate_enode!(bindings, g, pat_to_inst)) nothing end @@ -156,7 +158,8 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) isnothing(r) && return nothing rcid = addexpr!(g, r) - push!(g.merges_buffer, (id, rcid)) + push!(g.merges_buffer, id) + push!(g.merges_buffer, rcid) return nothing end @@ -177,7 +180,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end bindings = pop!(g.buffer) - rule_idx, id = bindings[0] + id, rule_idx = bindings[0] direction = sign(rule_idx) rule_idx = abs(rule_idx) rule = theory[rule_idx] @@ -198,7 +201,8 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end maybelock!(g) do while !isempty(g.merges_buffer) - (l, r) = pop!(g.merges_buffer) + l = pop!(g.merges_buffer) + r = pop!(g.merges_buffer) union!(g, l, r) end end diff --git a/src/EGraphs/unionfind.jl b/src/EGraphs/unionfind.jl index 0e19aa31..e2aa6ada 100644 --- a/src/EGraphs/unionfind.jl +++ b/src/EGraphs/unionfind.jl @@ -1,10 +1,10 @@ struct UnionFind - parents::Vector{Int} + parents::Vector{UInt} end -UnionFind() = UnionFind(Int[]) +UnionFind() = UnionFind(UInt[]) -function Base.push!(uf::UnionFind) +function Base.push!(uf::UnionFind)::UInt l = length(uf.parents) + 1 push!(uf.parents, l) l @@ -12,12 +12,12 @@ end Base.length(uf::UnionFind) = length(uf.parents) -function Base.union!(uf::UnionFind, i::Int, j::Int) +function Base.union!(uf::UnionFind, i::UInt, j::UInt) uf.parents[j] = i i end -function find(uf::UnionFind, i::Int) +function find(uf::UnionFind, i::UInt) while i != uf.parents[i] i = uf.parents[i] end diff --git a/src/Patterns.jl b/src/Patterns.jl index 2179cf65..546864b7 100644 --- a/src/Patterns.jl +++ b/src/Patterns.jl @@ -13,7 +13,7 @@ abstract type AbstractPat end struct PatHead head end -TermInterface.head_symbol(p::PatHead) = p.head +TermInterface.head_symbol(p::PatHead)::Symbol = p.head PatHead(p::PatHead) = error("recursive!") @@ -83,34 +83,38 @@ symbol `operation` and expression head `head.head`. struct PatTerm <: AbstractPat head::PatHead children::Vector - PatTerm(h, t::Vector) = new(h, t) + isground::Bool + PatTerm(h, t::Vector) = new(h, t, all(isground, t)) end PatTerm(eh, op) = PatTerm(eh, [op]) PatTerm(eh, children...) = PatTerm(eh, collect(children)) + +isground(p::PatTerm)::Bool = p.isground + TermInterface.istree(::PatTerm) = true TermInterface.head(p::PatTerm)::PatHead = p.head TermInterface.children(p::PatTerm) = p.children function TermInterface.operation(p::PatTerm) hs = head_symbol(head(p)) - hs == :call && return first(p.children) + hs in (:call, :macrocall) && return first(p.children) # hs == :ref && return getindex hs end function TermInterface.arguments(p::PatTerm) hs = head_symbol(head(p)) - hs == :call ? @view(p.children[2:end]) : p.children + hs in (:call, :macrocall) ? @view(p.children[2:end]) : p.children +end +function TermInterface.arity(p::PatTerm) + hs = head_symbol(head(p)) + l = length(p.children) + hs in (:call, :macrocall) ? l - 1 : l end -TermInterface.arity(p::PatTerm) = length(arguments(p)) TermInterface.metadata(p::PatTerm) = nothing TermInterface.maketerm(head::PatHead, children; type = Any, metadata = nothing) = PatTerm(head, children...) -isground(p::PatTerm) = all(isground, p.children) - - -# ============================================== -# ================== PATTERN VARIABLES ========= -# ============================================== +# --------------------- +# # Pattern Variables. """ Collects pattern variables appearing in a pattern into a vector of symbols @@ -122,9 +126,9 @@ patvars(x, s) = s patvars(p) = unique!(patvars(p, Symbol[])) -# ============================================== -# ================== DEBRUJIN INDEXING ========= -# ============================================== +# --------------------- +# # Debrujin Indexing. + function setdebrujin!(p::Union{PatVar,PatSegment}, pvars) p.idx = findfirst((==)(p.name), pvars) diff --git a/src/TermInterface.jl b/src/TermInterface.jl index cc17e5a9..d2fe5e59 100644 --- a/src/TermInterface.jl +++ b/src/TermInterface.jl @@ -114,7 +114,7 @@ export unsorted_arguments Returns the number of arguments of `x`. Implicitly defined if `arguments(x)` is defined. """ -arity(x) = length(arguments(x)) +arity(x)::Int = length(arguments(x)) export arity @@ -220,7 +220,7 @@ struct ExprHead end export ExprHead -head_symbol(eh::ExprHead) = eh.head +head_symbol(eh::ExprHead)::Symbol = eh.head istree(x::Expr) = true head(e::Expr) = ExprHead(e.head) @@ -247,6 +247,11 @@ function arguments(e::Expr) end end +function arity(e::Expr)::Int + l = length(e.args) + e.head in (:call, :macrocall) ? l - 1 : l +end + function maketerm(head::ExprHead, children; type = Any, metadata = nothing) if !isempty(children) && first(children) isa Union{Function,DataType} Expr(head.head, nameof(first(children)), @view(children[2:end])...) diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 0e5088fd..b6b19d04 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -32,7 +32,7 @@ end function predicate_ematcher(p::PatVar, pred) function predicate_ematcher(next, g, data, bindings) !islist(data) && return - id::Int = car(data) + id::UInt = car(data) eclass = g[id] if pred(eclass) enode_idx = 0 @@ -122,27 +122,27 @@ function ematcher(p::PatTerm) end -const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}() +const EMPTY_BINDINGS = Base.ImmutableDict{Int,Tuple{UInt,Int}}() """ -Substitutions are efficiently represented in memory as vector of tuples of two integers. -This should allow for static allocation of matches and use of LoopVectorization.jl -The buffer has to be fairly big when e-matching. -The size of the buffer should double when there's too many matches. -The format is as follows -* The first pair denotes the index of the rule in the theory and the e-class id - of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional - the direction is right-to-left. -* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable -* The end of a substitution is delimited by (0,0) +Substitutions are efficiently represented in memory as immutable dictionaries of tuples of two integers. + +The format is as follows: + +bindings[0] holds + 1. e-class-id of the node of the e-graph that is being substituted. + 2. the index of the rule in the theory. The rule number should be negative + if it's a bidirectional rule and the direction is right-to-left. + +The rest of the immutable dictionary bindings[n>0] represents (e-class id, literal position) at the position of the pattern variable `n`. """ function ematcher_yield(p, npvars::Int, direction::Int) em = ematcher(p) function ematcher_yield(g, rule_idx, id)::Int n_matches = 0 - em(g, (id,), EMPTY_ECLASS_DICT) do b, n + em(g, (id,), EMPTY_BINDINGS) do b, n maybelock!(g) do - push!(g.buffer, assoc(b, 0, (rule_idx * direction, id))) + push!(g.buffer, assoc(b, 0, (id, rule_idx * direction))) n_matches += 1 end end diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl index 9bd20711..493066bb 100644 --- a/test/egraphs/egraphs.jl +++ b/test/egraphs/egraphs.jl @@ -7,8 +7,8 @@ using Metatheory testmatch = :(a << 1) g = EGraph(testexpr) t2 = addexpr!(g, testmatch) - union!(g, t2, 3) - @test find(g, t2) == find(g, 3) + union!(g, t2, EClassId(3)) + @test find(g, t2) == find(g, EClassId(3)) # DOES NOT UPWARD MERGE end @@ -43,8 +43,8 @@ end t1 = addexpr!(g, apply(6, f, :a)) t2 = addexpr!(g, apply(9, f, :a)) - c_id = union!(g, t1, 1) # a == apply(6,f,a) - c2_id = union!(g, t2, 1) # a == apply(9,f,a) + c_id = union!(g, t1, EClassId(1)) # a == apply(6,f,a) + c2_id = union!(g, t2, EClassId(1)) # a == apply(9,f,a) rebuild!(g) @@ -52,10 +52,10 @@ end t4 = addexpr!(g, apply(7, f, :a)) # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a - @test find(g, t1) == find(g, 1) - @test find(g, t2) == find(g, 1) - @test find(g, t3) == find(g, 1) - @test find(g, t4) != find(g, 1) + @test find(g, t1) == find(g, EClassId(1)) + @test find(g, t2) == find(g, EClassId(1)) + @test find(g, t3) == find(g, EClassId(1)) + @test find(g, t4) != find(g, EClassId(1)) # if m or n is prime, f(a) = a t5 = addexpr!(g, apply(11, f, :a)) @@ -64,6 +64,6 @@ end rebuild!(g) - @test find(g, t5) == find(g, 1) - @test find(g, t6) == find(g, 1) + @test find(g, t5) == find(g, EClassId(1)) + @test find(g, t6) == find(g, EClassId(1)) end diff --git a/test/egraphs/unionfind.jl b/test/egraphs/unionfind.jl index 24fc4013..cf151e30 100644 --- a/test/egraphs/unionfind.jl +++ b/test/egraphs/unionfind.jl @@ -8,17 +8,15 @@ for _ in 1:n push!(uf) end -union!(uf, 1, 2) -union!(uf, 1, 3) -union!(uf, 1, 4) +union!(uf, UInt(1), UInt(2)) +union!(uf, UInt(1), UInt(3)) +union!(uf, UInt(1), UInt(4)) -union!(uf, 6, 8) -union!(uf, 6, 9) -union!(uf, 6, 10) +union!(uf, UInt(6), UInt(8)) +union!(uf, UInt(6), UInt(9)) +union!(uf, UInt(6), UInt(10)) for i in 1:n - find(uf, i) + find(uf, UInt(i)) end -@test uf.parents == [1, 1, 1, 1, 5, 6, 7, 6, 6, 6] - -# TODO test path compression \ No newline at end of file +@test uf.parents == UInt[1, 1, 1, 1, 5, 6, 7, 6, 6, 6]