Skip to content

Commit

Permalink
use UInt as id
Browse files Browse the repository at this point in the history
  • Loading branch information
a committed Jan 10, 2024
1 parent 0bd3021 commit 329595e
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 85 deletions.
44 changes: 21 additions & 23 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ Given an ENode `n`, `make` should return the corresponding analysis value.
"""
function make end

const EClassId = Int64
const TermTypes = Dict{Tuple{Any,Int},Type}
const EClassId = UInt64
# TODO document bindings
const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}}
const UNDEF_ARGS = Vector{EClassId}(undef, 0)
const Bindings = Base.ImmutableDict{Int,Tuple{EClassId,Int}}
const UNDEF_ID_VEC = Vector{EClassId}(undef, 0)

# @compactify begin
struct ENode
Expand All @@ -44,15 +43,15 @@ struct ENode
args::Vector{EClassId}
hash::Ref{UInt}
ENode(head, operation, args) = new(true, head, operation, args, Ref{UInt}(0))
ENode(literal) = new(false, nothing, literal, UNDEF_ARGS, Ref{UInt}(0))
ENode(literal) = new(false, nothing, literal, UNDEF_ID_VEC, Ref{UInt}(0))
end

TermInterface.istree(n::ENode) = n.istree
TermInterface.head(n::ENode) = n.head
TermInterface.operation(n::ENode) = n.operation
TermInterface.arguments(n::ENode) = n.args
TermInterface.children(n::ENode) = [n.operation; n.args...]
TermInterface.arity(n::ENode) = length(n.args)
TermInterface.arity(n::ENode)::Int = length(n.args)


# This optimization comes from SymbolicUtils
Expand All @@ -78,7 +77,7 @@ end

Base.show(io::IO, x::ENode) = print(io, to_expr(x))

function op_key(n)
"""
    op_key(n)

Build the `operation => arity` pair used to key `n` by its operation.

The first component is `operation(n)`, replaced by `nameof` of it when the
operation is a `Function` or a `DataType`. The second component is `arity(n)`
when `istree(n)` holds and `-1` otherwise.
"""
function op_key(n)::Pair{Any,Int}
    op = operation(n)
    # Functions and types are keyed by name rather than by the object itself.
    key = op isa Union{Function,DataType} ? nameof(op) : op
    nargs = istree(n) ? arity(n) : -1
    return key => nargs
end
Expand Down Expand Up @@ -155,7 +154,7 @@ mutable struct EGraph{Head,Analysis}
"Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
buffer::Vector{Bindings}
"Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
merges_buffer::Vector{Tuple{Int,Int}}
merges_buffer::Vector{EClassId}
lock::ReentrantLock
end

Expand All @@ -167,16 +166,16 @@ Construct an EGraph from a starting symbolic expression `expr`.
function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis}
EGraph{Head,Analysis}(
UnionFind(),
Dict{EClassId,EClass}(),
Dict{EClassId,EClass{Analysis}}(),
Dict{ENode,EClassId}(),
Pair{ENode,EClassId}[],
UniqueQueue{Pair{ENode,EClassId}}(),
-1,
0,
Dict{Pair{Any,Int},Vector{EClassId}}(),
false,
needslock,
Bindings[],
Tuple{Int,Int}[],
EClassId[],
ReentrantLock(),
)
end
Expand Down Expand Up @@ -232,7 +231,7 @@ end

function lookup(g::EGraph, n::ENode)::EClassId
cc = canonicalize(g, n)
haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1
haskey(g.memo, cc) ? find(g, g.memo[cc]) : 0
end


Expand Down Expand Up @@ -288,26 +287,22 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int
[`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly
insert the literal into the [`EGraph`](@ref).
"""
function addexpr!(g::EGraph, se, keepmeta = false)::EClassId
function addexpr!(g::EGraph, se)::EClassId
se isa EClass && return se.id
e = preprocess(se)

n = if istree(se)
args = arguments(e)
ar = length(args)
ar = arity(e)
class_ids = Vector{EClassId}(undef, ar)
for i in 1:ar
@inbounds class_ids[i] = addexpr!(g, args[i], keepmeta)
@inbounds class_ids[i] = addexpr!(g, args[i])
end
ENode(head(e), operation(e), class_ids)
else # constant enode
ENode(e)
end
id = add!(g, n)
if keepmeta
meta = TermInterface.metadata(e)
!isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta)
end
return id
end

Expand Down Expand Up @@ -512,15 +507,18 @@ function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head}

eh = Head(head_symbol(head(p)))

ids = map(x -> lookup_pat(g, x), args)
!all((>)(0), ids) && return -1
ids = Vector{EClassId}(undef, ar)
for i in 1:ar
@inbounds ids[i] = lookup_pat(g, args[i])
ids[i] <= 0 && return 0
end

if Head == ExprHead && op isa Union{Function,DataType}
id = lookup(g, ENode(eh, op, ids))
id < 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id
id <= 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id
else
lookup(g, ENode(eh, op, ids))
end
end

lookup_pat(g::EGraph, p::Any) = lookup(g, ENode(p))
lookup_pat(g::EGraph, p::Any)::EClassId = lookup(g, ENode(p))
18 changes: 11 additions & 7 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ Base.@kwdef mutable struct SaturationParams
timer::Bool = true
end

function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64}
"""
    cached_ids(g::EGraph, p::PatTerm)

Return the e-class ids of `g` associated with the pattern `p`.

- If `p` is ground, it is looked up with `lookup_pat`. The result is a
  one-element vector holding the e-class id that was found, or an empty vector
  when the lookup reports a miss.
- Otherwise the result is the vector stored in `g.classes_by_op` under
  `op_key(p)`, or an empty vector when there is no such entry.

The empty result is the shared constant `UNDEF_ID_VEC`; callers must not mutate it.
"""
function cached_ids(g::EGraph, p::PatTerm)::Vector{EClassId}
    if isground(p)
        id = lookup_pat(g, p)
        # `lookup_pat` is declared `::EClassId` and signals "not found" with 0, never
        # with `nothing`. Testing `isnothing(id)` therefore let a miss through as `[0]`,
        # an id that does not denote any e-class.
        return iszero(id) ? UNDEF_ID_VEC : EClassId[id]
    end
    return get(g.classes_by_op, op_key(p), UNDEF_ID_VEC)
end

Expand Down Expand Up @@ -115,13 +115,15 @@ function instantiate_enode!(bindings::Bindings, g::EGraph{Head}, p::PatTerm)::EC
end

function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right)))
push!(g.merges_buffer, id)
push!(g.merges_buffer, instantiate_enode!(buf, g, rule.right))
nothing
end

function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int)
pat_to_inst = direction == 1 ? rule.right : rule.left
push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst)))
push!(g.merges_buffer, id)
push!(g.merges_buffer, instantiate_enode!(bindings, g, pat_to_inst))
nothing
end

Expand Down Expand Up @@ -156,7 +158,8 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
isnothing(r) && return nothing
rcid = addexpr!(g, r)
push!(g.merges_buffer, (id, rcid))
push!(g.merges_buffer, id)
push!(g.merges_buffer, rcid)
return nothing
end

Expand All @@ -177,7 +180,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
end

bindings = pop!(g.buffer)
rule_idx, id = bindings[0]
id, rule_idx = bindings[0]
direction = sign(rule_idx)
rule_idx = abs(rule_idx)
rule = theory[rule_idx]
Expand All @@ -198,7 +201,8 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
end
maybelock!(g) do
while !isempty(g.merges_buffer)
(l, r) = pop!(g.merges_buffer)
l = pop!(g.merges_buffer)
r = pop!(g.merges_buffer)
union!(g, l, r)
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/EGraphs/unionfind.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
struct UnionFind
parents::Vector{Int}
parents::Vector{UInt}
end

UnionFind() = UnionFind(Int[])
UnionFind() = UnionFind(UInt[])

function Base.push!(uf::UnionFind)
function Base.push!(uf::UnionFind)::UInt
l = length(uf.parents) + 1
push!(uf.parents, l)
l
end

Base.length(uf::UnionFind) = length(uf.parents)

function Base.union!(uf::UnionFind, i::Int, j::Int)
function Base.union!(uf::UnionFind, i::UInt, j::UInt)
uf.parents[j] = i
i
end

function find(uf::UnionFind, i::Int)
function find(uf::UnionFind, i::UInt)
while i != uf.parents[i]
i = uf.parents[i]
end
Expand Down
32 changes: 18 additions & 14 deletions src/Patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract type AbstractPat end
struct PatHead
head
end
TermInterface.head_symbol(p::PatHead) = p.head
TermInterface.head_symbol(p::PatHead)::Symbol = p.head

PatHead(p::PatHead) = error("recursive!")

Expand Down Expand Up @@ -83,34 +83,38 @@ symbol `operation` and expression head `head.head`.
struct PatTerm <: AbstractPat
head::PatHead
children::Vector
PatTerm(h, t::Vector) = new(h, t)
isground::Bool
PatTerm(h, t::Vector) = new(h, t, all(isground, t))
end
PatTerm(eh, op) = PatTerm(eh, [op])
PatTerm(eh, children...) = PatTerm(eh, collect(children))

isground(p::PatTerm)::Bool = p.isground

TermInterface.istree(::PatTerm) = true
TermInterface.head(p::PatTerm)::PatHead = p.head
TermInterface.children(p::PatTerm) = p.children
function TermInterface.operation(p::PatTerm)
hs = head_symbol(head(p))
hs == :call && return first(p.children)
hs in (:call, :macrocall) && return first(p.children)
# hs == :ref && return getindex
hs
end
function TermInterface.arguments(p::PatTerm)
hs = head_symbol(head(p))
hs == :call ? @view(p.children[2:end]) : p.children
hs in (:call, :macrocall) ? @view(p.children[2:end]) : p.children
end
# Number of arguments of the pattern term `p`, computed from the length of
# `p.children` without building an argument view. For `:call` and `:macrocall`
# heads one child is not counted.
function TermInterface.arity(p::PatTerm)
    hs = head_symbol(head(p))
    nchildren = length(p.children)
    if hs in (:call, :macrocall)
        return nchildren - 1
    end
    return nchildren
end
TermInterface.arity(p::PatTerm) = length(arguments(p))
TermInterface.metadata(p::PatTerm) = nothing

TermInterface.maketerm(head::PatHead, children; type = Any, metadata = nothing) = PatTerm(head, children...)

isground(p::PatTerm) = all(isground, p.children)


# ==============================================
# ================== PATTERN VARIABLES =========
# ==============================================
# ---------------------
# # Pattern Variables.

"""
Collects pattern variables appearing in a pattern into a vector of symbols
Expand All @@ -122,9 +126,9 @@ patvars(x, s) = s
patvars(p) = unique!(patvars(p, Symbol[]))


# ==============================================
# ================== DEBRUJIN INDEXING =========
# ==============================================
# ---------------------
# # De Bruijn Indexing.


function setdebrujin!(p::Union{PatVar,PatSegment}, pvars)
p.idx = findfirst((==)(p.name), pvars)
Expand Down
9 changes: 7 additions & 2 deletions src/TermInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ export unsorted_arguments
Returns the number of arguments of `x`. Implicitly defined
if `arguments(x)` is defined.
"""
arity(x) = length(arguments(x))
arity(x)::Int = length(arguments(x))
export arity


Expand Down Expand Up @@ -220,7 +220,7 @@ struct ExprHead
end
export ExprHead

head_symbol(eh::ExprHead) = eh.head
head_symbol(eh::ExprHead)::Symbol = eh.head

istree(x::Expr) = true
head(e::Expr) = ExprHead(e.head)
Expand All @@ -247,6 +247,11 @@ function arguments(e::Expr)
end
end

"""
    arity(e::Expr)

Return the number of arguments of `e`, computed directly from `length(e.args)`.

For `:call` and `:macrocall` expressions one entry of `e.args` is not counted;
for every other head all entries of `e.args` are counted.
"""
function arity(e::Expr)::Int
    nargs = length(e.args)
    if e.head === :call || e.head === :macrocall
        return nargs - 1
    end
    return nargs
end

function maketerm(head::ExprHead, children; type = Any, metadata = nothing)
if !isempty(children) && first(children) isa Union{Function,DataType}
Expr(head.head, nameof(first(children)), @view(children[2:end])...)
Expand Down
28 changes: 14 additions & 14 deletions src/ematch_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
function predicate_ematcher(p::PatVar, pred)
function predicate_ematcher(next, g, data, bindings)
!islist(data) && return
id::Int = car(data)
id::UInt = car(data)
eclass = g[id]
if pred(eclass)
enode_idx = 0
Expand Down Expand Up @@ -122,27 +122,27 @@ function ematcher(p::PatTerm)
end


const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}()
const EMPTY_BINDINGS = Base.ImmutableDict{Int,Tuple{UInt,Int}}()

"""
Substitutions are efficiently represented in memory as vector of tuples of two integers.
This should allow for static allocation of matches and use of LoopVectorization.jl
The buffer has to be fairly big when e-matching.
The size of the buffer should double when there's too many matches.
The format is as follows
* The first pair denotes the index of the rule in the theory and the e-class id
of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional
the direction is right-to-left.
* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable
* The end of a substitution is delimited by (0,0)
Substitutions are represented in memory as immutable dictionaries that map an integer key to a tuple of two integers.
The format is as follows:
`bindings[0]` holds, in order:
1. the e-class id of the node of the e-graph that is being substituted;
2. the index of the rule in the theory. The rule index is negative
if the rule is bidirectional and is being applied right-to-left.
Every other entry `bindings[n]`, with `n > 0`, holds the pair (e-class id, literal position) bound to the pattern variable `n`.
"""
function ematcher_yield(p, npvars::Int, direction::Int)
em = ematcher(p)
function ematcher_yield(g, rule_idx, id)::Int
n_matches = 0
em(g, (id,), EMPTY_ECLASS_DICT) do b, n
em(g, (id,), EMPTY_BINDINGS) do b, n
maybelock!(g) do
push!(g.buffer, assoc(b, 0, (rule_idx * direction, id)))
push!(g.buffer, assoc(b, 0, (id, rule_idx * direction)))
n_matches += 1
end
end
Expand Down
Loading

0 comments on commit 329595e

Please sign in to comment.