Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hash collisions in e-graph saturation #229

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ mutable struct EGraph{ExpressionType,Analysis}
uf::UnionFind
"map from eclass id to eclasses"
classes::Dict{IdKey,EClass{Analysis}}
"hashcons mapping e-node hashes to their e-class id"
memo::Dict{IdKey,Id}
"hashcons mapping e-nodes to their e-class id"
memo::Dict{VecExpr,Id}
"Hashcons the constants in the e-graph"
constants::Dict{UInt64,Any}
"Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass."
Expand All @@ -143,7 +143,7 @@ function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {Expre
EGraph{ExpressionType,Analysis}(
UnionFind(),
Dict{IdKey,EClass{Analysis}}(),
Dict{IdKey,Id}(),
Dict{VecExpr,Id}(),
Dict{UInt64,Any}(),
Pair{VecExpr,Id}[],
UniqueQueue{Pair{VecExpr,Id}}(),
Expand Down Expand Up @@ -254,7 +254,7 @@ function lookup(g::EGraph, n::VecExpr)::Id
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, h) ? find(g, g.memo[h]) : 0
haskey(g.memo, n) ? find(g, g.memo[n]) : 0
end


Expand All @@ -272,9 +272,8 @@ Inserts an e-node in an [`EGraph`](@ref)
"""
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, h) && return g.memo[h]
haskey(g.memo, n) && return g.memo[n]

if should_copy
n = copy(n)
Expand All @@ -288,7 +287,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
end
end

g.memo[h] = id
g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n))
Expand Down Expand Up @@ -338,7 +337,7 @@ function addexpr!(g::EGraph, se)::Id
end
n
else # constant enode
Id[Id(0), Id(0), Id(0), add_constant!(g, e)]
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
end
id = add!(g, n, false)
return id
Expand Down Expand Up @@ -432,10 +431,9 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
canonicalize!(g, node)
h = IdKey(v_hash(node))
if haskey(g.memo, h)
old_class_id = g.memo[h]
g.memo[h] = eclass_id
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
did_something = union!(g, old_class_id, eclass_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
Expand Down Expand Up @@ -485,7 +483,7 @@ function check_memo(g::EGraph)::Bool

for (node, id) in test_memo
@assert id == find(g, id)
@assert id == find(g, g.memo[IdKey(v_hash(node))])
@assert id == find(g, g.memo[node])
end

true
Expand Down
2 changes: 1 addition & 1 deletion src/Patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ isground(p::AbstractPat) = false
struct PatLiteral <: AbstractPat
value
n::VecExpr
PatLiteral(val) = new(val, Id[0, 0, 0, hash(val)])
PatLiteral(val) = new(val, VecExpr(Id[0, 0, 0, hash(val)]))
end

PatLiteral(p::AbstractPat) = throw(DomainError(p, "Cannot construct a pattern literal of another pattern object."))
Expand Down
53 changes: 34 additions & 19 deletions src/vecexpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,32 @@ export Id,
const Id = UInt64

"""
const VecExpr = Vector{Id}
struct VecExpr
data::Vector{Id}
end

An e-node is a `Vector{Id}` where:
An e-node is represented by `Vector{Id}` where:
* Position 1 stores the hash of the `VecExpr`.
* Position 2 stores the bit flags (`isexpr` or `iscall`).
* Position 3 stores the signature
* Position 4 stores the hash of the `head` (if `isexpr`) or node value in the e-graph constants.
* The rest of the positions store the e-class ids of the children nodes.

The expression is represented as an array of integers to improve performance.
The hash value for the VecExpr is cached in the first position for faster lookup performance in dictionaries.
"""
const VecExpr = Vector{Id}
struct VecExpr
data::Vector{Id}
end

const VECEXPR_FLAG_ISTREE = 0x01
const VECEXPR_FLAG_ISCALL = 0x10
const VECEXPR_META_LENGTH = 4

@inline v_flags(n::VecExpr)::Id = @inbounds n[2]
@inline v_unset_flags!(n::VecExpr) = @inbounds (n[2] = 0)
@inline v_flags(n::VecExpr)::Id = @inbounds n.data[2]
@inline v_unset_flags!(n::VecExpr) = @inbounds (n.data[2] = 0)
@inline v_check_flags(n::VecExpr, flag::Id)::Bool = !iszero(v_flags(n) & flags)
@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n[2] = n[2] | flag)
@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n.data[2] = n.data[2] | flag)

"""Returns `true` if the e-node ID points to a an expression tree."""
@inline v_isexpr(n::VecExpr)::Bool = !iszero(v_flags(n) & VECEXPR_FLAG_ISTREE)
Expand All @@ -56,54 +63,62 @@ const VECEXPR_META_LENGTH = 4
@inline v_iscall(n::VecExpr)::Bool = !iszero(v_flags(n) & VECEXPR_FLAG_ISCALL)

"""Number of children in the e-node."""
@inline v_arity(n::VecExpr)::Int = length(n) - VECEXPR_META_LENGTH
@inline v_arity(n::VecExpr)::Int = length(n.data) - VECEXPR_META_LENGTH

"""
Compute the hash of a `VecExpr` and store it as the first element.
"""
@inline function v_hash!(n::VecExpr)::Id
if iszero(n[1])
n[1] = hash(@view n[2:end])
if iszero(n.data[1])
n.data[1] = hash(@view n.data[2:end])
else
# h = hash(@view n[2:end])
# @assert h == n[1]
n[1]
n.data[1]
end
end

"""The hash of the e-node."""
@inline v_hash(n::VecExpr)::Id = @inbounds n[1]
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])

"""Set e-node hash to zero."""
@inline v_unset_hash!(n::VecExpr)::Id = @inbounds (n[1] = Id(0))
@inline v_unset_hash!(n::VecExpr)::Id = @inbounds (n.data[1] = Id(0))

"""E-class IDs of the children of the e-node."""
@inline v_children(n::VecExpr) = @view n[(VECEXPR_META_LENGTH + 1):end]
@inline v_children(n::VecExpr) = @view n.data[(VECEXPR_META_LENGTH + 1):end]

@inline v_signature(n::VecExpr)::Id = @inbounds n[3]
@inline v_signature(n::VecExpr)::Id = @inbounds n.data[3]

@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n[3] = sig)
@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n.data[3] = sig)

"The constant ID of the operation of the e-node, or the e-node ."
@inline v_head(n::VecExpr)::Id = @inbounds n[VECEXPR_META_LENGTH]
@inline v_head(n::VecExpr)::Id = @inbounds n.data[VECEXPR_META_LENGTH]

"Update the E-Node operation ID."
@inline v_set_head!(n::VecExpr, h::Id) = @inbounds (n[VECEXPR_META_LENGTH] = h)
@inline v_set_head!(n::VecExpr, h::Id) = @inbounds (n.data[VECEXPR_META_LENGTH] = h)

"""Construct a new, empty `VecExpr` with `len` children."""
@inline function v_new(len::Int)::VecExpr
n = Vector{Id}(undef, len + VECEXPR_META_LENGTH)
n = VecExpr(Vector{Id}(undef, len + VECEXPR_META_LENGTH))
v_unset_hash!(n)
v_unset_flags!(n)
n
end

@inline v_children_range(n::VecExpr) = ((VECEXPR_META_LENGTH + 1):length(n))
@inline v_children_range(n::VecExpr) = ((VECEXPR_META_LENGTH + 1):length(n.data))


v_pair(a::UInt64, b::UInt64) = UInt128(a) << 64 | b
v_pair_first(p::UInt128)::UInt64 = UInt64(p >> 64)
v_pair_last(p::UInt128)::UInt64 = UInt64(p & 0xffffffffffffffff)

@inline Base.length(n::VecExpr) = length(n.data)
@inline Base.getindex(n::VecExpr, i) = n.data[i]
@inline Base.setindex!(n::VecExpr, val, i) = n.data[i] = val
@inline Base.copy(n::VecExpr) = VecExpr(copy(n.data))
@inline Base.lastindex(n::VecExpr) = lastindex(n.data)
@inline Base.firstindex(n::VecExpr) = firstindex(n.data)

end