Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add debugging and GraphViz visualization utilities #169

Merged
merged 13 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ makedocs(
"index.md"
"rewrite.md"
"egraphs.md"
"visualizing.md"
"api.md"
"Tutorials" => tutorials
],
Expand Down
240 changes: 240 additions & 0 deletions docs/src/assets/graphviz.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 41 additions & 0 deletions docs/src/visualizing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Visualizing E-Graphs

You can visualize e-graphs in VSCode by using [GraphViz.jl]()

All you need to do is to install GraphViz.jl and to evaluate an e-graph after including the extra script:

```julia
using GraphViz

include(dirname(pathof(Metatheory)) * "/extras/graphviz.jl")

algebra_rules = @theory a b c begin
a * (b * c) == (a * b) * c
a + (b + c) == (a + b) + c

a + b == b + a
a * (b + c) == (a * b) + (a * c)
(a + b) * c == (a * c) + (b * c)

-a == -1 * a
a - b == a + -b
1 * a == a

0 * a --> 0
a + 0 --> a

a::Number * b == b * a::Number
a::Number * b::Number => a * b
a::Number + b::Number => a + b
end;

ex = :(a - a)
g = EGraph(ex)
params = SaturationParams(; timeout = 2)
saturate!(g, algebra_rules, params)
g
```

And you will see a nice e-graph drawing in the Julia Plots VSCode panel:

![E-Graph Drawing](/assets/graphviz.svg)
2 changes: 1 addition & 1 deletion src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using DataStructures
using TermInterface
using TimerOutputs
using Metatheory:
alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.EMatchCompiler
Expand Down
8 changes: 0 additions & 8 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,6 @@ end
Inserts an e-node in an [`EGraph`](@ref)
"""
function add!(g::EGraph, n::AbstractENode)::EClassId
@debug("adding ", n)

n = canonicalize(g, n)
haskey(g.memo, n) && return g.memo[n]

Expand Down Expand Up @@ -378,9 +376,6 @@ function Base.merge!(g::EGraph, a::EClassId, b::EClassId)::EClassId

id_a == id_b && return id_a
to = union!(g.uf, id_a, id_b)

@debug "merging" id_a id_b

from = (to == id_a) ? id_b : id_a

push!(g.dirty, to)
Expand Down Expand Up @@ -432,15 +427,13 @@ function repair!(g::EGraph, id::EClassId)
id = find(g, id)
ecdata = g[id]
ecdata.id = id
@debug "repairing " id

new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}()

for (p_enode, p_eclass) in ecdata.parents
p_enode = canonicalize!(g, p_enode)
# deduplicate parents
if haskey(new_parents, p_enode)
@debug "merging classes" p_eclass (new_parents[p_enode])
merge!(g, p_eclass, new_parents[p_enode])
end
n_id = find(g, p_eclass)
Expand All @@ -449,7 +442,6 @@ function repair!(g::EGraph, id::EClassId)
end

ecdata.parents = collect(new_parents)
@debug "updated parents " id g.parents[id]

# ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes)

Expand Down
25 changes: 11 additions & 14 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ Base.@kwdef mutable struct SaturationParams
schedulerparams::Tuple = ()
threaded::Bool = false
timer::Bool = true
printiter::Bool = false
end

# function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64}
Expand Down Expand Up @@ -124,17 +123,20 @@ function eqsat_search!(
empty!(BUFFER[])
end

@debug "SEARCHING"
for (rule_idx, rule) in enumerate(theory)
@timeit report.to string(rule_idx) begin
# don't apply banned rules
if !cansearch(scheduler, rule)
@debug "$rule is banned"
continue
end
ids = cached_ids(g, rule.left)
rule isa BidirRule && (ids = ids ∪ cached_ids(g, rule.right))
for i in ids
n_matches += rule.ematcher!(g, rule_idx, i)
end
n_matches > 0 && @debug "Rule $rule_idx: $rule produced $n_matches matches"
inform!(scheduler, rule, n_matches)
end
end
Expand Down Expand Up @@ -180,7 +182,7 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClas
other_id = instantiate_enode!(bindings, g, pat_to_inst)

if find(g, id) == find(g, other_id)
@log "Contradiction!" rule
@debug "$rule produced a contradiction!"
return :contradiction
end
nothing
Expand All @@ -191,7 +193,7 @@ Instantiate argument for dynamic rule application in e-graph
"""
function instantiate_actual_param!(bindings::Bindings, g::EGraph, i)
ecid, literal_position = bindings[i]
ecid <= 0 && error("unbound pattern variable $pat in rule $rule")
ecid <= 0 && error("unbound pattern variable")
eclass = g[ecid]
if literal_position > 0
@assert eclass[literal_position] isa ENodeLiteral
Expand All @@ -215,10 +217,12 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
i = 0
@assert isempty(MERGES_BUF[])

@debug "APPLYING $(length(BUFFER[])) matches"

lock(BUFFER_LOCK) do
while !isempty(BUFFER[])
if reached(g, params.goal)
@log "Goal reached"
@debug "Goal reached"
rep.reason = :goalreached
return
end
Expand Down Expand Up @@ -249,10 +253,6 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
end



import ..@log


"""
Core algorithm of the library: the equality saturation step.
"""
Expand All @@ -276,6 +276,8 @@ function eqsat_step!(
end
@timeit report.to "Rebuild" rebuild!(g)

@debug smallest_expr = extract!(g, astsize)

return report
end

Expand All @@ -297,7 +299,7 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
while true
curr_iter += 1

params.printiter && @info("iteration ", curr_iter)
@debug "================ EQSAT ITERATION $curr_iter ================"

report = eqsat_step!(g, theory, curr_iter, sched, params, report)

Expand Down Expand Up @@ -328,7 +330,6 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
end
end
report.iterations = curr_iter
@log report

return report
end
Expand All @@ -339,13 +340,9 @@ function areequal(theory::Vector, exprs...; params = SaturationParams())
end

function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams())
@log "Checking equality for " exprs
if length(exprs) == 1
return true
end
# rebuild!(G)

@log "starting saturation"

n = length(exprs)
ids = map(Base.Fix1(addexpr!, g), collect(exprs))
Expand Down
6 changes: 0 additions & 6 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ using Base.Meta
using Reexport
using TermInterface

macro log(args...)
quote
haskey(ENV, "MT_DEBUG") && @info($(args...))
end |> esc
end

@inline alwaystrue(x) = true

function lookup_pat end
Expand Down
36 changes: 20 additions & 16 deletions src/Syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Remove LineNumberNode from quoted blocks of code
rmlines(e::Expr) = Expr(e.head, map(rmlines, filter(x -> !(x isa LineNumberNode), e.args))...)
rmlines(a) = a

function_object_or_quote(op::Symbol, mod)::Expr = :(isdefined($mod, $(QuoteNode(op))) ? $op : $(QuoteNode(op)))
function_object_or_quote(op, mod) = op

function makesegment(s::Expr, pvars)
if !(exprhead(s) == :(::))
Expand Down Expand Up @@ -84,38 +86,41 @@ function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false)
head = exprhead(ex)
op = operation(ex)
# Retrieve the function object if available
# Optionally quote function objects
args = arguments(ex)
istree(op) && (op = makepattern(op, pvars, slots, mod))

if head === :call
if operation(ex) === :(~) # is a variable or segment
if args[1] isa Expr && operation(args[1]) == :(~)
# matches ~~x::predicate or ~~x::predicate...
return makesegment(arguments(args[1])[1], pvars)
elseif splat
# matches ~x::predicate...
return makesegment(args[1], pvars)
else
return makevar(args[1], pvars)
let v = args[1]
if v isa Expr && operation(v) == :(~)
# matches ~~x::predicate or ~~x::predicate...
makesegment(arguments(v)[1], pvars)
elseif splat
# matches ~x::predicate...
makesegment(v, pvars)
else
makevar(v, pvars)
end
end
else # is a term
else # Matches a term
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
return :($PatTerm(:call, $op, [$(patargs...)]))
:($PatTerm(:call, $(function_object_or_quote(op, mod)), [$(patargs...)]))
end

elseif head === :...
makepattern(args[1], pvars, slots, mod, true)
elseif head == :(::) && args[1] in slots
return splat ? makesegment(ex, pvars) : makevar(ex, pvars)
splat ? makesegment(ex, pvars) : makevar(ex, pvars)
elseif head === :ref
# getindex
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
return :($PatTerm(:ref, getindex, [$(patargs...)]))
:($PatTerm(:ref, getindex, [$(patargs...)]))
elseif head === :$
return args[1]
args[1]
else
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
return :($PatTerm($(QuoteNode(head)), $(op isa Symbol ? QuoteNode(op) : op), [$(patargs...)]))
# throw(Meta.ParseError("Unsupported pattern syntax $ex"))
:($PatTerm($(QuoteNode(head)), $(function_object_or_quote(op, mod)), [$(patargs...)]))
end
end

Expand Down Expand Up @@ -328,7 +333,6 @@ macro rule(args...)

e = macroexpand(__module__, expr)
e = rmlines(e)
op = operation(e)
RuleType = rule_sym_map(e)

l, r = arguments(e)
Expand Down
96 changes: 96 additions & 0 deletions src/extras/graphviz.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
using GraphViz
using Metatheory
using TermInterface

function render_egraph!(io::IO, g::EGraph)
print(
io,
"""digraph {
compound=true
clusterrank=local
remincross=false
ranksep=0.9
""",
)
for (_, eclass) in g.classes
render_eclass!(io, g, eclass)
end
println(io, "\n}\n")
end

function render_eclass!(io::IO, g::EGraph, eclass::EClass)
print(
io,
""" subgraph cluster_$(eclass.id) {
style="dotted,rounded";
rank=same;
label="#$(eclass.id). Smallest: $(extract!(g, astsize; root=eclass.id))"
fontcolor = gray
fontsize = 8
""",
)

# if g.root == find(g, eclass.id)
# println(io, " penwidth=2")
# end

for (i, node) in enumerate(eclass.nodes)
render_enode_node!(io, g, eclass.id, i, node)
end
print(io, "\n }\n")

for (i, node) in enumerate(eclass.nodes)
render_enode_edges!(io, g, eclass.id, i, node)
end
println(io)
end


function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::AbstractENode)
label = operation(node)
# (mr, style) = if node in diff && get(report.cause, node, missing) !== missing
# pair = get(report.cause, node, missing)
# split(split("$(pair[1].rule) ", "=>")[1], "-->")[1], " color=\"red\""
# else
# " ", ""
# end
# sg *= " $id.$os [label=<$label<br /><font point-size=\"8\" color=\"gray\">$mr</font>> $style];"
println(io, " $eclass_id.$i [label=<$label> shape=box style=rounded]")
end

render_enode_edges!(::IO, ::EGraph, eclass_id, i, ::ENodeLiteral) = nothing

function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::ENodeTerm)
len = length(arguments(node))
for (ite, child) in enumerate(arguments(node))
cluster_id = find(g, child)
# The limitation of graphviz is that it cannot point to the eclass outer frame,
# so when pointing to the same e-class, the next best thing is to point to the same e-node.
target_id = "$cluster_id" * (cluster_id == eclass_id ? ".$i" : ".1")

# In order from left to right, if there are more than 3 children, label the order.
dir = if len == 2
ite == 1 ? ":sw" : ":se"
elseif len == 3
ite == 1 ? ":sw" : (ite == 2 ? ":s" : ":se")
else
""
end

linelabel = len > 3 ? " label=$ite" : " "
println(io, " $eclass_id.$i$dir -> $target_id [arrowsize=0.5 lhead=cluster_$cluster_id $linelabel]")
end
end

function Base.convert(::Type{GraphViz.Graph}, g::EGraph)::GraphViz.Graph
io = IOBuffer()
render_egraph!(io, g)
gs = String(take!(io))
g = GraphViz.Graph(gs)
GraphViz.layout!(g; engine = "dot")
g
end

function Base.show(io::IO, mime::MIME"image/svg+xml", g::EGraph)
show(io, mime, convert(GraphViz.Graph, g))
end
Loading
Loading