Performance improvements #253

Draft
wants to merge 24 commits into
base: ale/3.0
Changes from 9 commits
Commits
2e1e46d
Output parent lists and check parents after rebuilding
gkronber Oct 10, 2024
1301422
Path splitting procedure to shorten path length with find call.
gkronber Oct 10, 2024
0704eb1
Merge branch 'fix_broken_cas_tests' into performance_improvements
gkronber Oct 10, 2024
500e25e
Store original e-nodes in egraph and keep only e-node ids in parent l…
gkronber Oct 10, 2024
ac4829c
Revert changes to pretty_dict output.
gkronber Oct 10, 2024
6341b3b
Change lookup in classes_by_op dictionary to prevent allocation of a …
gkronber Oct 10, 2024
9fb097c
Fixed implementation of iterate for optbuffer (currently only affecte…
gkronber Oct 10, 2024
80a4696
Fix compile error.
gkronber Oct 10, 2024
68d40d1
Find of eclass_id for enode_id is not necessary here
gkronber Oct 10, 2024
013253e
Add some test assertions for internal datastructures used for egraph …
gkronber Oct 12, 2024
4e6bc9f
Merge branch 'ale/3.0' into performance_improvements
gkronber Oct 12, 2024
9db374d
Set root to allow debugging (requires extraction)
gkronber Oct 13, 2024
3ac0718
isless for VecExpr to allow sorting.
gkronber Oct 13, 2024
90719d0
Check most specific constants first.
gkronber Oct 13, 2024
6ddffa4
Comment and removed unnecessary parentheses.
gkronber Oct 13, 2024
3c07d51
Fixes for constant matching from different PR
gkronber Oct 13, 2024
97c272b
Allow to set SaturationParams for simplify for testing, and mark two …
gkronber Oct 13, 2024
917f3fe
Small fixes.
gkronber Oct 13, 2024
e6f582c
Correct test cases for rebuilding
gkronber Oct 13, 2024
50d6dfd
Complete overhaul of rebuilding mechanism.
gkronber Oct 13, 2024
ee1862e
Fixes, moving forward...
gkronber Oct 14, 2024
348dd62
Bugfix in analysis rebuilding.
gkronber Oct 14, 2024
4d26e3c
Minor changes.
gkronber Oct 14, 2024
0d203a8
Remove nodes vector from egraph and clean-up code.
gkronber Oct 14, 2024
2 changes: 1 addition & 1 deletion examples/prove.jl
@@ -1,4 +1,4 @@
# Sketch function for basic iterative saturation and extraction
# Sketch function for basic iterative saturation and extraction
function prove(
t,
ex,
90 changes: 61 additions & 29 deletions src/EGraphs/egraph.jl
@@ -7,7 +7,7 @@

The `modify!` function for EGraph Analysis can optionally modify the eclass
`eclass` after it has been analyzed, typically by adding an e-node.
It should be **idempotent** if no other changes occur to the EClass.
It should be **idempotent** if no other changes occur to the EClass.
(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)).
"""
function modify! end
@@ -25,7 +25,7 @@ function join end
"""
make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr)::AnalysisType where {ExpressionType}

Given an e-node `n`, `make` should return the corresponding analysis value.
Given an e-node `n`, `make` should return the corresponding analysis value.
"""
function make end

@@ -42,7 +42,7 @@ they represent. The [`EGraph`](@ref) itself comes with pretty printing for human
struct EClass{D}
id::Id
nodes::Vector{VecExpr}
Collaborator

Why not also have this be a vector of Id, if we now store VecExprs in nodes?

Collaborator Author

Yes, this could be better. I already tested this and reverted it, because it introduces an awkward indirection of the form g.nodes[class.nodes[i]] in several places, most notably the e-matching code. Local tests showed no performance improvement from storing only the ids here. I have not decided yet.

My current approach is to store the same node objects (VecExpr) in the nodes vectors of the e-classes and of the e-graph. As far as I understand, we do not need to keep the original (uncanonicalized) e-nodes.

parents::Vector{Pair{VecExpr,Id}}
parents::Vector{Id} # The original Ids of parent enodes.
data::Union{D,Nothing}
end

@@ -65,10 +65,6 @@ function Base.show(io::IO, a::EClass)
end
end

function addparent!(@nospecialize(a::EClass), n::VecExpr, id::Id)
push!(a.parents, (n => id))
end


function merge_analysis_data!(a::EClass{D}, b::EClass{D})::Tuple{Bool,Bool,Union{D,Nothing}} where {D}
if !isnothing(a.data) && !isnothing(b.data)
@@ -119,13 +115,15 @@ mutable struct EGraph{ExpressionType,Analysis}
uf::UnionFind
"map from eclass id to eclasses"
classes::Dict{IdKey,EClass{Analysis}}
"vector of the original e-nodes"
nodes::Vector{VecExpr}
Member

I wonder if later on we can figure out something more efficient than vectors

Collaborator Author

For this purpose a vector is quite natural.
New enodes are only pushed at the end, and we can simply index enodes by their non-canonical index.

"hashcons mapping e-nodes to their e-class id"
memo::Dict{VecExpr,Id}
"Hashcons the constants in the e-graph"
constants::Dict{UInt64,Any}
"Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass."
pending::Vector{Pair{VecExpr,Id}}
analysis_pending::UniqueQueue{Pair{VecExpr,Id}}
"Nodes which need to be processed for rebuilding. The id is the id of the e-node, not the canonical id of the e-class."
pending::Vector{Id}
analysis_pending::UniqueQueue{Id}
root::Id
"a cache mapping signatures (function symbols and their arity) to e-classes that contain e-nodes with that function symbol."
classes_by_op::Dict{IdKey,Vector{Id}}
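
The reply above about indexing e-nodes by their non-canonical id can be illustrated with a toy model. This is only a sketch under stated assumptions: `NodeStore` and `add_node!` are hypothetical names, plain `Vector{Int}` stands in for `VecExpr`, and the id is assumed to equal the node's position in the vector.

```julia
# Toy sketch; `NodeStore` and `add_node!` are hypothetical names, not package API.
struct NodeStore
  nodes::Vector{Vector{Int}}  # stand-in for Vector{VecExpr}
  pending::Vector{Int}        # ids of e-nodes waiting to be processed
end

NodeStore() = NodeStore(Vector{Int}[], Int[])

function add_node!(store::NodeStore, node::Vector{Int})
  push!(store.nodes, node)  # new e-nodes are only ever appended
  id = length(store.nodes)  # assumption: the id is the position in the vector
  push!(store.pending, id)  # queue the id, not the node itself
  id
end

store = NodeStore()
a = add_node!(store, [10])
b = add_node!(store, [20, a])

# Processing pops an id and looks the node up by index.
next_id = pop!(store.pending)
next_node = store.nodes[next_id]
```

Because the queue holds plain ids, the pending list and the parent lists no longer need to carry `Pair{VecExpr,Id}` values.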
@@ -144,10 +142,11 @@ function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {Expre
EGraph{ExpressionType,Analysis}(
UnionFind(),
Dict{IdKey,EClass{Analysis}}(),
Vector{VecExpr}(),
Dict{VecExpr,Id}(),
Dict{UInt64,Any}(),
Pair{VecExpr,Id}[],
UniqueQueue{Pair{VecExpr,Id}}(),
Id[],
UniqueQueue{Id}(),
0,
Dict{IdKey,Vector{Id}}(),
false,
@@ -200,7 +199,7 @@ end
function pretty_dict(g::EGraph)
d = Dict{Int,Vector{Any}}()
for (class_id, eclass) in g.classes
d[class_id.val] = map(n -> to_expr(g, n), eclass.nodes)
d[class_id.val] = (map(n -> to_expr(g, n), eclass.nodes))
end
d
end
@@ -209,8 +208,8 @@ export pretty_dict
function Base.show(io::IO, g::EGraph)
d = pretty_dict(g)
t = "$(typeof(g)) with $(length(d)) e-classes:"
cs = map(sort!(collect(d); by = first)) do (k, vect)
" $k => [$(Base.join(vect, ", "))]"
cs = map(sort!(collect(d); by = first)) do (k, nodes)
" $k => [$(Base.join(nodes, ", "))]"
end
print(io, Base.join([t; cs], "\n"))
end
@@ -245,7 +244,11 @@ end

function add_class_by_op(g::EGraph, n, eclass_id)
key = IdKey(v_signature(n))
vec = get!(g.classes_by_op, key, Vector{Id}())
vec = get(g.classes_by_op, key, nothing)
if isnothing(vec)
vec = Id[eclass_id]
g.classes_by_op[key] = vec
end
Member

Why not vec = get!(g.classes_by_op, key, Id[eclass_id])?

Collaborator Author

The idea was that when Id[eclass_id] is supplied as the default argument, we always allocate a new vector, which is discarded immediately if the key already exists.
It would probably be best to use

get!(Vector{Id}, g.classes_by_op, key)

push!(vec, eclass_id)
end
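
The `get!(Vector{Id}, g.classes_by_op, key)` form suggested in the thread above can be sketched as follows. This is a minimal example, assuming plain `Int` keys and ids in place of `IdKey` and `Id`; `add_class_by_op!` here is an illustrative stand-in, not the function from the diff. Note that in the hunk as shown, the new-key branch seeds the vector with `eclass_id` and the `push!` after it appears to append the same id a second time; the `get!` form avoids both that and the throwaway allocation.

```julia
# Stand-in for `g.classes_by_op`, using plain Int keys and ids (an assumption for brevity).
classes_by_op = Dict{Int,Vector{Int}}()

function add_class_by_op!(index::Dict{Int,Vector{Int}}, key::Int, eclass_id::Int)
  # `get!(f, dict, key)` calls `f()` only when `key` is absent,
  # so no vector is allocated when the key already exists.
  ids = get!(Vector{Int}, index, key)
  push!(ids, eclass_id)
  ids
end

add_class_by_op!(classes_by_op, 7, 1)
add_class_by_op!(classes_by_op, 7, 2)
```

After the two calls, key `7` maps to a single vector holding both ids, each exactly once.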

@@ -263,28 +266,29 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
end

id = push!(g.uf) # create new singleton eclass
push!(g.nodes, n)

if v_isexpr(n)
for c_id in v_children(n)
addparent!(g.classes[IdKey(c_id)], n, id)
push!(g.classes[IdKey(c_id)].parents, id)
end
end

g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
eclass = EClass{Analysis}(id, VecExpr[n], Id[], make(g, n)) # TODO: check do we need to copy n for the nodes vector here?
g.classes[IdKey(id)] = eclass
modify!(g, eclass)
push!(g.pending, n => id)
push!(g.pending, id)

return id
end


"""
Extend this function on your types to do preliminary
preprocessing of a symbolic term before adding it to
preprocessing of a symbolic term before adding it to
an EGraph. Most common preprocessing techniques are binarization
of n-ary terms and metadata stripping.
"""
@@ -407,29 +411,30 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp

while !isempty(g.pending) || !isempty(g.analysis_pending)
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
node = copy(node)
enode_id = pop!(g.pending)
node = copy(g.nodes[enode_id])
Member

Is the copy necessary here?

Collaborator Author
gkronber Oct 11, 2024

Probably not. This will still undergo some changes.

canonicalize!(g, node)
old_class_id = get!(g.memo, node, eclass_id)
if old_class_id != eclass_id
did_something = union!(g, old_class_id, eclass_id)
memo_class = get!(g.memo, node, enode_id)
if memo_class != enode_id
did_something = union!(g, memo_class, enode_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
n_unions += did_something
end
end

while !isempty(g.analysis_pending)
(node::VecExpr, eclass_id::Id) = pop!(g.analysis_pending)
eclass_id = find(g, eclass_id)
enode_id = pop!(g.analysis_pending)
node = g.nodes[enode_id]
eclass_id = find(g, enode_id)
eclass_id_key = IdKey(eclass_id)
eclass = g.classes[eclass_id_key]

node_data = make(g, node)
if !isnothing(node_data)
if !isnothing(eclass.data)
joined_data = join(eclass.data, node_data)

if joined_data != eclass.data
g.classes[eclass_id_key] = EClass{AnalysisType}(eclass_id, eclass.nodes, eclass.parents, joined_data)
# eclass.data = joined_data
@@ -448,6 +453,32 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
n_unions
end

function check_parents(g::EGraph)::Bool
for (id, class) in g.classes
# make sure that the parent node and parent eclass occur in the parents vector of every child
for n in class.nodes
for chd_id in v_children(n)
chd_class = g[chd_id]
any(nid -> canonicalize!(g, copy(g.nodes[nid])) == n, chd_class.parents) || error("parent node is missing from child_class.parents")
any(nid -> find(g, nid) == id.val, chd_class.parents) || error("missing parent reference from child")
end
end

# make sure all nodes and parent ids occurring in the parent vector have this eclass as a child
for nid in class.parents
parent_class = g[nid]
any(n -> any(ch -> ch == id.val, v_children(n)), parent_class.nodes) || error("no node in the parent references the eclass") # nodes are canonicalized

parent_node = g.nodes[nid]
parent_node_copy = copy(parent_node)
canonicalize!(g, parent_node_copy)
(parent_node_copy in parent_class.nodes) || error("the node from the parent list does not occur in the parent nodes") # might fail because parent_node is probably not canonical
end
end

true
end
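
The two-way invariant that `check_parents` tests can be shown on a toy model. The sketch below is a simplification and not the package's check: it assumes one node per class, ignores union-find canonicalization, and uses hypothetical names (`children`, `parents`, `parents_consistent`).

```julia
# Toy model: node id => ids of its children, class id => ids of its parent nodes.
children = Dict(1 => Int[], 2 => Int[], 3 => [1, 2])
parents = Dict(1 => [3], 2 => [3], 3 => Int[])

function parents_consistent(children, parents)
  # Direction 1: every child must list the node as one of its parents.
  for (node_id, kids) in children
    for kid in kids
      node_id in parents[kid] || return false
    end
  end
  # Direction 2: every listed parent must actually have the class as a child.
  for (class_id, parent_ids) in parents
    for p in parent_ids
      class_id in children[p] || return false
    end
  end
  true
end

ok = parents_consistent(children, parents)

# Dropping node 3 from the parent list of class 1 breaks direction 1.
broken = parents_consistent(children, Dict(1 => Int[], 2 => [3], 3 => Int[]))
```

The check in the diff additionally canonicalizes a copy of each stored e-node before comparing, since the stored nodes may not be canonical.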

function check_memo(g::EGraph)::Bool
test_memo = Dict{VecExpr,Id}()
for (id, class) in g.classes
@@ -483,9 +514,10 @@ upwards merging in an [`EGraph`](@ref). See
the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)
for more details.
"""
function rebuild!(g::EGraph; should_check_memo=false, should_check_analysis=false)
function rebuild!(g::EGraph; should_check_parents=false, should_check_memo=false, should_check_analysis=false)
n_unions = process_unions!(g)
trimmed_nodes = rebuild_classes!(g)
@assert !should_check_parents || check_parents(g)
@assert !should_check_memo || check_memo(g)
@assert !should_check_analysis || check_analysis(g)
g.clean = true
2 changes: 2 additions & 0 deletions src/EGraphs/saturation.jl
@@ -40,6 +40,8 @@ Base.@kwdef mutable struct SaturationParams
check_memo::Bool = false
"Activate check for join-semilattice invariant for semantic analysis values after rebuilding"
check_analysis::Bool = false
"Activate check for parent vectors"
check_parents::Bool = false
end

function cached_ids(g::EGraph, p::PatExpr)::Vector{Id}
4 changes: 3 additions & 1 deletion src/EGraphs/unionfind.jl
@@ -18,8 +18,10 @@ function Base.union!(uf::UnionFind, i::Id, j::Id)
end

function find(uf::UnionFind, i::Id)
# path splitting
while i != uf.parents[i]
i = uf.parents[i]
(i, uf.parents[i]) = (uf.parents[i], uf.parents[uf.parents[i]])
end

i
end
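
For readers unfamiliar with path splitting, here is a minimal standalone sketch of the technique. `SimpleUnionFind` and `find_root!` are hypothetical names and do not mirror the package's `UnionFind`. The sketch uses explicit temporaries so that it does not depend on the order in which a destructuring assignment updates `i` and `uf.parents[i]`.

```julia
# Minimal union-find with path splitting; names are illustrative, not the package's API.
struct SimpleUnionFind
  parents::Vector{Int}
end

# Start with n singleton sets: every element is its own parent.
SimpleUnionFind(n::Int) = SimpleUnionFind(collect(1:n))

function find_root!(uf::SimpleUnionFind, i::Int)
  while i != uf.parents[i]
    parent = uf.parents[i]
    uf.parents[i] = uf.parents[parent]  # point i at its grandparent
    i = parent                          # continue from the old parent
  end
  i
end

# Build the chain 5 -> 4 -> 3 -> 2 -> 1 by hand.
uf = SimpleUnionFind(5)
uf.parents[2] = 1
uf.parents[3] = 2
uf.parents[4] = 3
uf.parents[5] = 4

root = find_root!(uf, 5)
```

After the call, every node on the visited path points to its former grandparent, which roughly halves the path length for later lookups.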
2 changes: 1 addition & 1 deletion src/optbuffer.jl
@@ -34,4 +34,4 @@ end
Base.isempty(b::OptBuffer{T}) where {T} = b.i === 0
Base.empty!(b::OptBuffer{T}) where {T} = (b.i = 0)
@inline Base.length(b::OptBuffer{T}) where {T} = b.i
Base.iterate(b::OptBuffer{T}, i=1) where {T} = iterate(b.v[1:b.i], i)
Base.iterate(b::OptBuffer{T}, i=1) where {T} = i <= b.i ? (b.v[i], i + 1) : nothing
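
The index-based `iterate` can be tried out on a small stand-in type. This is a sketch assuming a buffer made of a backing vector plus a fill counter; `TinyBuffer` is a hypothetical name and not the package's `OptBuffer`. The point is that the iteration state is just the next index, whereas an expression like `b.v[1:b.i]` builds a fresh copy of the slice each time `iterate` is called.

```julia
# Hypothetical stand-in for the buffer type: a backing vector plus a fill counter.
mutable struct TinyBuffer{T}
  v::Vector{T}
  i::Int  # number of valid elements in `v`
end

Base.length(b::TinyBuffer) = b.i
Base.eltype(::Type{TinyBuffer{T}}) where {T} = T

# The state is the next index to read; no intermediate array is created.
Base.iterate(b::TinyBuffer, state = 1) = state <= b.i ? (b.v[state], state + 1) : nothing

buf = TinyBuffer([10, 20, 30, 0, 0], 3)
collected = collect(buf)
total = sum(buf)
```

Only the first `b.i` slots are visited, so the unused tail of the backing vector never leaks into the iteration.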