Skip to content

Commit

Permalink
Merge pull request #169 from JuliaSymbolics/ale/debugging
Browse files Browse the repository at this point in the history
Add debugging and GraphViz visualization utilities
  • Loading branch information
0x0f0f0f authored Oct 22, 2023
2 parents 39349f9 + ef978aa commit 7828344
Show file tree
Hide file tree
Showing 12 changed files with 413 additions and 58 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ makedocs(
"index.md"
"rewrite.md"
"egraphs.md"
"visualizing.md"
"api.md"
"Tutorials" => tutorials
],
Expand Down
240 changes: 240 additions & 0 deletions docs/src/assets/graphviz.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 41 additions & 0 deletions docs/src/visualizing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Visualizing E-Graphs

You can visualize e-graphs in VSCode by using [GraphViz.jl]()

All you need to do is to install GraphViz.jl and to evaluate an e-graph after including the extra script:

```julia
using GraphViz

include(dirname(pathof(Metatheory)) * "/extras/graphviz.jl")

algebra_rules = @theory a b c begin
a * (b * c) == (a * b) * c
a + (b + c) == (a + b) + c

a + b == b + a
a * (b + c) == (a * b) + (a * c)
(a + b) * c == (a * c) + (b * c)

-a == -1 * a
a - b == a + -b
1 * a == a

0 * a --> 0
a + 0 --> a

a::Number * b == b * a::Number
a::Number * b::Number => a * b
a::Number + b::Number => a + b
end;

ex = :(a - a)
g = EGraph(ex)
params = SaturationParams(; timeout = 2)
saturate!(g, algebra_rules, params)
g
```

And you will see a nice e-graph drawing in the Julia Plots VSCode panel:

![E-Graph Drawing](/assets/graphviz.svg)
2 changes: 1 addition & 1 deletion src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using DataStructures
using TermInterface
using TimerOutputs
using Metatheory:
alwaystrue, cleanast, binarize, @log, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
alwaystrue, cleanast, binarize, DEFAULT_BUFFER_SIZE, BUFFER, BUFFER_LOCK, MERGES_BUF, MERGES_BUF_LOCK, Bindings
using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.EMatchCompiler
Expand Down
8 changes: 0 additions & 8 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,6 @@ end
Inserts an e-node in an [`EGraph`](@ref)
"""
function add!(g::EGraph, n::AbstractENode)::EClassId
@debug("adding ", n)

n = canonicalize(g, n)
haskey(g.memo, n) && return g.memo[n]

Expand Down Expand Up @@ -378,9 +376,6 @@ function Base.merge!(g::EGraph, a::EClassId, b::EClassId)::EClassId

id_a == id_b && return id_a
to = union!(g.uf, id_a, id_b)

@debug "merging" id_a id_b

from = (to == id_a) ? id_b : id_a

push!(g.dirty, to)
Expand Down Expand Up @@ -432,15 +427,13 @@ function repair!(g::EGraph, id::EClassId)
id = find(g, id)
ecdata = g[id]
ecdata.id = id
@debug "repairing " id

new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}()

for (p_enode, p_eclass) in ecdata.parents
p_enode = canonicalize!(g, p_enode)
# deduplicate parents
if haskey(new_parents, p_enode)
@debug "merging classes" p_eclass (new_parents[p_enode])
merge!(g, p_eclass, new_parents[p_enode])
end
n_id = find(g, p_eclass)
Expand All @@ -449,7 +442,6 @@ function repair!(g::EGraph, id::EClassId)
end

ecdata.parents = collect(new_parents)
@debug "updated parents " id g.parents[id]

# ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes)

Expand Down
25 changes: 11 additions & 14 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ Base.@kwdef mutable struct SaturationParams
schedulerparams::Tuple = ()
threaded::Bool = false
timer::Bool = true
printiter::Bool = false
end

# function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64}
Expand Down Expand Up @@ -124,17 +123,20 @@ function eqsat_search!(
empty!(BUFFER[])
end

@debug "SEARCHING"
for (rule_idx, rule) in enumerate(theory)
@timeit report.to string(rule_idx) begin
# don't apply banned rules
if !cansearch(scheduler, rule)
@debug "$rule is banned"
continue
end
ids = cached_ids(g, rule.left)
rule isa BidirRule && (ids = ids cached_ids(g, rule.right))
for i in ids
n_matches += rule.ematcher!(g, rule_idx, i)
end
n_matches > 0 && @debug "Rule $rule_idx: $rule produced $n_matches matches"
inform!(scheduler, rule, n_matches)
end
end
Expand Down Expand Up @@ -180,7 +182,7 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::UnequalRule, id::EClas
other_id = instantiate_enode!(bindings, g, pat_to_inst)

if find(g, id) == find(g, other_id)
@log "Contradiction!" rule
@debug "$rule produced a contradiction!"
return :contradiction
end
nothing
Expand All @@ -191,7 +193,7 @@ Instantiate argument for dynamic rule application in e-graph
"""
function instantiate_actual_param!(bindings::Bindings, g::EGraph, i)
ecid, literal_position = bindings[i]
ecid <= 0 && error("unbound pattern variable $pat in rule $rule")
ecid <= 0 && error("unbound pattern variable")
eclass = g[ecid]
if literal_position > 0
@assert eclass[literal_position] isa ENodeLiteral
Expand All @@ -215,10 +217,12 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
i = 0
@assert isempty(MERGES_BUF[])

@debug "APPLYING $(length(BUFFER[])) matches"

lock(BUFFER_LOCK) do
while !isempty(BUFFER[])
if reached(g, params.goal)
@log "Goal reached"
@debug "Goal reached"
rep.reason = :goalreached
return
end
Expand Down Expand Up @@ -249,10 +253,6 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
end



import ..@log


"""
Core algorithm of the library: the equality saturation step.
"""
Expand All @@ -276,6 +276,8 @@ function eqsat_step!(
end
@timeit report.to "Rebuild" rebuild!(g)

@debug smallest_expr = extract!(g, astsize)

return report
end

Expand All @@ -297,7 +299,7 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
while true
curr_iter += 1

params.printiter && @info("iteration ", curr_iter)
@debug "================ EQSAT ITERATION $curr_iter ================"

report = eqsat_step!(g, theory, curr_iter, sched, params, report)

Expand Down Expand Up @@ -328,7 +330,6 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
end
end
report.iterations = curr_iter
@log report

return report
end
Expand All @@ -339,13 +340,9 @@ function areequal(theory::Vector, exprs...; params = SaturationParams())
end

function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams())
@log "Checking equality for " exprs
if length(exprs) == 1
return true
end
# rebuild!(G)

@log "starting saturation"

n = length(exprs)
ids = map(Base.Fix1(addexpr!, g), collect(exprs))
Expand Down
6 changes: 0 additions & 6 deletions src/Metatheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ using Base.Meta
using Reexport
using TermInterface

macro log(args...)
quote
haskey(ENV, "MT_DEBUG") && @info($(args...))
end |> esc
end

@inline alwaystrue(x) = true

function lookup_pat end
Expand Down
36 changes: 20 additions & 16 deletions src/Syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Remove LineNumberNode from quoted blocks of code
rmlines(e::Expr) = Expr(e.head, map(rmlines, filter(x -> !(x isa LineNumberNode), e.args))...)
rmlines(a) = a

function_object_or_quote(op::Symbol, mod)::Expr = :(isdefined($mod, $(QuoteNode(op))) ? $op : $(QuoteNode(op)))
function_object_or_quote(op, mod) = op

function makesegment(s::Expr, pvars)
if !(exprhead(s) == :(::))
Expand Down Expand Up @@ -84,38 +86,41 @@ function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false)
head = exprhead(ex)
op = operation(ex)
# Retrieve the function object if available
# Optionally quote function objects
args = arguments(ex)
istree(op) && (op = makepattern(op, pvars, slots, mod))

if head === :call
if operation(ex) === :(~) # is a variable or segment
if args[1] isa Expr && operation(args[1]) == :(~)
# matches ~~x::predicate or ~~x::predicate...
return makesegment(arguments(args[1])[1], pvars)
elseif splat
# matches ~x::predicate...
return makesegment(args[1], pvars)
else
return makevar(args[1], pvars)
let v = args[1]
if v isa Expr && operation(v) == :(~)
# matches ~~x::predicate or ~~x::predicate...
makesegment(arguments(v)[1], pvars)
elseif splat
# matches ~x::predicate...
makesegment(v, pvars)
else
makevar(v, pvars)
end
end
else # is a term
else # Matches a term
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
return :($PatTerm(:call, $op, [$(patargs...)]))
:($PatTerm(:call, $(function_object_or_quote(op, mod)), [$(patargs...)]))
end

elseif head === :...
makepattern(args[1], pvars, slots, mod, true)
elseif head == :(::) && args[1] in slots
return splat ? makesegment(ex, pvars) : makevar(ex, pvars)
splat ? makesegment(ex, pvars) : makevar(ex, pvars)
elseif head === :ref
# getindex
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
return :($PatTerm(:ref, getindex, [$(patargs...)]))
:($PatTerm(:ref, getindex, [$(patargs...)]))
elseif head === :$
return args[1]
args[1]
else
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
return :($PatTerm($(QuoteNode(head)), $(op isa Symbol ? QuoteNode(op) : op), [$(patargs...)]))
# throw(Meta.ParseError("Unsupported pattern syntax $ex"))
:($PatTerm($(QuoteNode(head)), $(function_object_or_quote(op, mod)), [$(patargs...)]))
end
end

Expand Down Expand Up @@ -328,7 +333,6 @@ macro rule(args...)

e = macroexpand(__module__, expr)
e = rmlines(e)
op = operation(e)
RuleType = rule_sym_map(e)

l, r = arguments(e)
Expand Down
96 changes: 96 additions & 0 deletions src/extras/graphviz.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
using GraphViz
using Metatheory
using TermInterface

function render_egraph!(io::IO, g::EGraph)
print(
io,
"""digraph {
compound=true
clusterrank=local
remincross=false
ranksep=0.9
""",
)
for (_, eclass) in g.classes
render_eclass!(io, g, eclass)
end
println(io, "\n}\n")
end

function render_eclass!(io::IO, g::EGraph, eclass::EClass)
print(
io,
""" subgraph cluster_$(eclass.id) {
style="dotted,rounded";
rank=same;
label="#$(eclass.id). Smallest: $(extract!(g, astsize; root=eclass.id))"
fontcolor = gray
fontsize = 8
""",
)

# if g.root == find(g, eclass.id)
# println(io, " penwidth=2")
# end

for (i, node) in enumerate(eclass.nodes)
render_enode_node!(io, g, eclass.id, i, node)
end
print(io, "\n }\n")

for (i, node) in enumerate(eclass.nodes)
render_enode_edges!(io, g, eclass.id, i, node)
end
println(io)
end


function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::AbstractENode)
label = operation(node)
# (mr, style) = if node in diff && get(report.cause, node, missing) !== missing
# pair = get(report.cause, node, missing)
# split(split("$(pair[1].rule) ", "=>")[1], "-->")[1], " color=\"red\""
# else
# " ", ""
# end
# sg *= " $id.$os [label=<$label<br /><font point-size=\"8\" color=\"gray\">$mr</font>> $style];"
println(io, " $eclass_id.$i [label=<$label> shape=box style=rounded]")
end

render_enode_edges!(::IO, ::EGraph, eclass_id, i, ::ENodeLiteral) = nothing

function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::ENodeTerm)
len = length(arguments(node))
for (ite, child) in enumerate(arguments(node))
cluster_id = find(g, child)
# The limitation of graphviz is that it cannot point to the eclass outer frame,
# so when pointing to the same e-class, the next best thing is to point to the same e-node.
target_id = "$cluster_id" * (cluster_id == eclass_id ? ".$i" : ".1")

# In order from left to right, if there are more than 3 children, label the order.
dir = if len == 2
ite == 1 ? ":sw" : ":se"
elseif len == 3
ite == 1 ? ":sw" : (ite == 2 ? ":s" : ":se")
else
""
end

linelabel = len > 3 ? " label=$ite" : " "
println(io, " $eclass_id.$i$dir -> $target_id [arrowsize=0.5 lhead=cluster_$cluster_id $linelabel]")
end
end

function Base.convert(::Type{GraphViz.Graph}, g::EGraph)::GraphViz.Graph
io = IOBuffer()
render_egraph!(io, g)
gs = String(take!(io))
g = GraphViz.Graph(gs)
GraphViz.layout!(g; engine = "dot")
g
end

function Base.show(io::IO, mime::MIME"image/svg+xml", g::EGraph)
show(io, mime, convert(GraphViz.Graph, g))
end
Loading

0 comments on commit 7828344

Please sign in to comment.