Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix enode memoization #238

Closed
wants to merge 11 commits into from
5 changes: 1 addition & 4 deletions .github/workflows/benchmark_pr.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: Benchmark pull request

on:
pull_request:
branches:
- master
Comment on lines -3 to -6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is also changed in master i think

on: [pull_request]

permissions:
pull-requests: write
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
name: CI
on:
pull_request:
branches:
- master
push:
branches:
- master
Expand Down
84 changes: 29 additions & 55 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,38 +223,20 @@ Returns the canonical e-class id for a given e-class.

@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))]

# function canonicalize(g::EGraph, n::VecExpr)::VecExpr
# if !v_isexpr(n)
# v_hash!(n)
# return n
# end
# l = v_arity(n)
# new_n = v_new(l)
# v_set_flag!(new_n, v_flags(n))
# v_set_head!(new_n, v_head(n))
# for i in v_children_range(n)
# @inbounds new_n[i] = find(g, n[i])
# end
# v_hash!(new_n)
# new_n
# end

function canonicalize!(g::EGraph, n::VecExpr)
v_isexpr(n) || @goto ret
for i in (VECEXPR_META_LENGTH + 1):length(n)
@inbounds n[i] = find(g, n[i])
if v_isexpr(n)
for i in (VECEXPR_META_LENGTH + 1):length(n)
@inbounds n[i] = find(g, n[i])
end
end
v_unset_hash!(n)
@label ret
v_hash!(n)
Comment on lines -247 to -249
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the hash going to mutate as well? What is the difference from caching it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right, the main issue is that VecExpr must not be updated after they have been added to memo. Caching of hash values is an independent concern.
I did a quick analysis, in which I checked how often the cached values are actually used and I saw only a small usage factor. I'll redo the analysis more carefully and post the result here.

The cached value can make up 15% to 20% of the memory required for VecExpr.

Copy link
Collaborator Author

@gkronber gkronber Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the annotations in https://github.com/gkronber/Metatheory.jl/tree/count_vecexpr_hash_calls
(gkronber@593d9e2)

julia> using Metatheory
Precompiling Metatheory
  1 dependency successfully precompiled in 2 seconds. 6 already precompiled.

julia> include("benchmark/benchmarks.jl")
Benchmark(evals=1, seconds=5.0, samples=10000)

julia> run(SUITE)
[...]

julia> Metatheory.VecExprModule.vexpr_created
12355173
julia> Metatheory.VecExprModule.v_copy_calls
5247759
julia> Metatheory.VecExprModule.v_new_calls
3975935
julia> Metatheory.VecExprModule.unset_hash_calls
41556639
julia> Metatheory.VecExprModule.hash_calls
93167599
julia> Metatheory.VecExprModule.cached_hash_computation
40712170
julia> Metatheory.VecExprModule.cached_hash_access
1205819
julia> Metatheory.EGraphs.memo_lookups
31998547
julia> Metatheory.EGraphs.memo_add
15708295

In this run of the benchmarks:

  • 12mio VecExpr objects are constructed (some via v_new, some via copy, most via direct constructor calls)
  • 32mio lookups are done in memo (note that most of them call the hash function twice via haskey(n) && g.memo[n]).
  • 15mio times g.memo[n] is set (memo_add)
  • memo lookups and adding to memo cause 93mio calls to hash(n::VecExpr)
  • v_hash! is called 42mio times (40.7mio times the hash value is calculated, 1.2mio times the cached value is returned)
  • -> only a bit more than half of the hash calls can use the cached value
  • if instead of haskey(n) && g.memo[n] we use get we should be able to reduce the number of hash calls significantly (exact number to be added in a later edit).
  • improving memo lookups we have only 64mio calls to hash(n::VecExpr). The other numbers are unchanged. We still need to calculate the hash 40.7mio times. (65346c7)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VecExpr must not be updated after they have been added to memo.

I think this happens as well in egg and is part of the algorithm. We should ask the egg community how they are doing it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In egg they are careful to clone nodes before adding them to memo.

n
end

function lookup(g::EGraph, n::VecExpr)::Id
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, n) ? find(g, g.memo[n]) : 0
id = get(g.memo, n, zero(Id))
iszero(id) ? id : find(g, id)
end


Expand All @@ -272,9 +254,10 @@ Inserts an e-node in an [`EGraph`](@ref)
"""
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)

haskey(g.memo, n) && return g.memo[n]


id = get(g.memo, n, zero(Id))
iszero(id) || return id

if should_copy
n = copy(n)
end
Expand Down Expand Up @@ -319,28 +302,21 @@ function addexpr!(g::EGraph, se)::Id
se isa EClass && return se.id
e = preprocess(se)

n = if isexpr(e)
args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)

h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
n
else # constant enode
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), add_constant!(g, e)]), false) # constant enode

args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))
# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
id = add!(g, n, false)
return id
add!(g, n, false)
end

"""
Expand Down Expand Up @@ -431,9 +407,8 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
canonicalize!(g, node)
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
old_class_id = get!(g.memo, node, eclass_id)
if old_class_id != eclass_id
did_something = union!(g, old_class_id, eclass_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
Expand Down Expand Up @@ -473,17 +448,16 @@ function check_memo(g::EGraph)::Bool
for (id, class) in g.classes
@assert id.val == class.id
for node in class.nodes
if haskey(test_memo, node)
old_id = test_memo[node]
test_memo[node] = id.val
old_id = get!(test_memo, node, id.val)
if old_id != id.val
@assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)"
end
end
end

for (node, id) in test_memo
@assert id == find(g, id)
@assert id == find(g, g.memo[node])
@assert id == find(g, g.memo[node]) "Entry for $node at $id in test_memo was incorrect."
end

true
Expand Down
2 changes: 1 addition & 1 deletion src/Patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ isground(p::AbstractPat) = false
struct PatLiteral <: AbstractPat
value
n::VecExpr
PatLiteral(val) = new(val, VecExpr(Id[0, 0, 0, hash(val)]))
PatLiteral(val) = new(val, VecExpr(Id[0, 0, hash(val)]))
end

PatLiteral(p::AbstractPat) = throw(DomainError(p, "Cannot construct a pattern literal of another pattern object."))
Expand Down
43 changes: 11 additions & 32 deletions src/vecexpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,25 @@ const Id = UInt64
end

An e-node is represented by `Vector{Id}` where:
* Position 1 stores the hash of the `VecExpr`.
* Position 2 stores the bit flags (`isexpr` or `iscall`).
* Position 3 stores the signature
* Position 4 stores the hash of the `head` (if `isexpr`) or node value in the e-graph constants.
* Position 1 stores the bit flags (`isexpr` or `iscall`).
* Position 2 stores the signature
* Position 3 stores the hash of the `head` (if `isexpr`) or node value in the e-graph constants.
* The rest of the positions store the e-class ids of the children nodes.

The expression is represented as an array of integers to improve performance.
The hash value for the VecExpr is cached in the first position for faster lookup performance in dictionaries.
"""
struct VecExpr
data::Vector{Id}
end

const VECEXPR_FLAG_ISTREE = 0x01
const VECEXPR_FLAG_ISCALL = 0x10
const VECEXPR_META_LENGTH = 4
const VECEXPR_META_LENGTH = 3

@inline v_flags(n::VecExpr)::Id = @inbounds n.data[2]
@inline v_unset_flags!(n::VecExpr) = @inbounds (n.data[2] = 0)
@inline v_flags(n::VecExpr)::Id = @inbounds n.data[1]
@inline v_unset_flags!(n::VecExpr) = @inbounds (n.data[1] = 0)
@inline v_check_flags(n::VecExpr, flag::Id)::Bool = !iszero(v_flags(n) & flags)
@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n.data[2] = n.data[2] | flag)
@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n.data[1] = n.data[1] | flag)

"""Returns `true` if the e-node ID points to a an expression tree."""
@inline v_isexpr(n::VecExpr)::Bool = !iszero(v_flags(n) & VECEXPR_FLAG_ISTREE)
Expand All @@ -65,33 +63,15 @@ const VECEXPR_META_LENGTH = 4
"""Number of children in the e-node."""
@inline v_arity(n::VecExpr)::Int = length(n.data) - VECEXPR_META_LENGTH

"""
Compute the hash of a `VecExpr` and store it as the first element.
"""
@inline function v_hash!(n::VecExpr)::Id
if iszero(n.data[1])
n.data[1] = hash(@view n.data[2:end])
else
# h = hash(@view n[2:end])
# @assert h == n[1]
n.data[1]
end
end

"""The hash of the e-node."""
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])

"""Set e-node hash to zero."""
@inline v_unset_hash!(n::VecExpr)::Id = @inbounds (n.data[1] = Id(0))
Base.hash(n::VecExpr) = hash(n.data)
Base.:(==)(a::VecExpr, b::VecExpr) = a.data == b.data

"""E-class IDs of the children of the e-node."""
@inline v_children(n::VecExpr) = @view n.data[(VECEXPR_META_LENGTH + 1):end]

@inline v_signature(n::VecExpr)::Id = @inbounds n.data[3]
@inline v_signature(n::VecExpr)::Id = @inbounds n.data[2]

@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n.data[3] = sig)
@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n.data[2] = sig)

"The constant ID of the operation of the e-node, or the e-node ."
@inline v_head(n::VecExpr)::Id = @inbounds n.data[VECEXPR_META_LENGTH]
Expand All @@ -102,7 +82,6 @@ Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:en
"""Construct a new, empty `VecExpr` with `len` children."""
@inline function v_new(len::Int)::VecExpr
n = VecExpr(Vector{Id}(undef, len + VECEXPR_META_LENGTH))
v_unset_hash!(n)
v_unset_flags!(n)
n
end
Expand Down
Loading