Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lambda theory test #236

Merged
merged 10 commits into from
Aug 23, 2024
8 changes: 4 additions & 4 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ end
function merge_analysis_data!(a::EClass{D}, b::EClass{D})::Tuple{Bool,Bool,Union{D,Nothing}} where {D}
if !isnothing(a.data) && !isnothing(b.data)
new_a_data = join(a.data, b.data)
(a.data == new_a_data, b.data == new_a_data, new_a_data)
(a.data != new_a_data, b.data != new_a_data, new_a_data)
elseif isnothing(a.data) && !isnothing(b.data)
# a merged, b not merged
(true, false, b.data)
Expand Down Expand Up @@ -504,11 +504,11 @@ upwards merging in an [`EGraph`](@ref). See
the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)
for more details.
"""
function rebuild!(g::EGraph)
function rebuild!(g::EGraph, activate_memo_check=false, activate_analysis_check=false)
n_unions = process_unions!(g)
trimmed_nodes = rebuild_classes!(g)
# @assert check_memo(g)
# @assert check_analysis(g)
@assert !activate_memo_check || check_memo(g)
@assert !activate_analysis_check || check_analysis(g)
g.clean = true

@debug "REBUILT" n_unions trimmed_nodes
Expand Down
4 changes: 3 additions & 1 deletion src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Base.@kwdef mutable struct SaturationParams
schedulerparams::NamedTuple = (;)
threaded::Bool = false
timer::Bool = true
check_memo::Bool = false # activate check of memoization of nodes (hashcons) after rebuilding
check_analysis::Bool = false # activate check of join-semilattice invariant for semantic analysis values after rebuilding
0x0f0f0f marked this conversation as resolved.
Show resolved Hide resolved
end

function cached_ids(g::EGraph, p::PatExpr)::Vector{Id}
Expand Down Expand Up @@ -288,7 +290,7 @@ function eqsat_step!(
if report.reason === nothing && cansaturate(scheduler) && isempty(g.pending)
report.reason = :saturated
end
@timeit report.to "Rebuild" rebuild!(g)
@timeit report.to "Rebuild" rebuild!(g, params.check_memo, params.check_analysis)

Schedulers.rebuild!(scheduler)

Expand Down
65 changes: 51 additions & 14 deletions test/tutorials/lambda_theory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,47 +102,55 @@ end

const LambdaAnalysis = Set{Symbol}

getdata(eclass) = isnothing(eclass.data) ? LambdaAnalysis() : eclass.data
getdata(eclass) = eclass.data

function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType}
v_isexpr(n) || LambdaAnalysis()
v_isexpr(n) || return LambdaAnalysis()
if v_iscall(n)
h = v_head(n)
op = get_constant(g, h)
args = v_children(n)
eclass = g[args[1]]
free = getdata(eclass)
free = copy(getdata(eclass))

if op == Variable
push!(free, get_constant(g, v_head(eclass.nodes[1])))

elseif op == Let
v, a, b = args[1:3]
v, a, b = args[1:3] # v=a in b
vclass = g[v]
vsy = get_constant(g, v_head(vclass.nodes[1]))
adata = getdata(g[a])
bdata = getdata(g[b])
union!(free, adata)
delete!(free, v)
union!(free, bdata)

delete!(free, vsy)
union!(free, adata)
elseif op == λ
v, b = args[1:2]
vclass = g[v]
vsy = get_constant(g, v_head(vclass.nodes[1]))
bdata = getdata(g[b])
union!(free, bdata)
delete!(free, v)

delete!(free, vsy)
elseif op == Apply
l, v = args[1:2]
ldata = getdata(g[l])
vdata = getdata(g[v])
union!(free, ldata)
union!(free, vdata)

end
return free
end
end

EGraphs.join(from::LambdaAnalysis, to::LambdaAnalysis) = union(from, to)
function EGraphs.join(from::LambdaAnalysis, to::LambdaAnalysis)
if issubset(from, to) # includes case from==to
from
elseif issubset(to, from)
to
else
error("inconsistent free variable sets from: $from to: $to")
end
0x0f0f0f marked this conversation as resolved.
Show resolved Hide resolved
end

function fresh_var_generator()
idx = 0
Expand All @@ -159,6 +167,7 @@ freshvar = fresh_var_generator()
# The final ruleset then looks like below and correctly renames variables when needed:

λT = @theory v e c v1 v2 a b body begin
# let(v,e,body) means let v = e in body
Let(v, e, c::Any) --> c
Let(v1, e, Variable(v1)) --> e
Let(v1, e, Variable(v2)) => v1 == v2 ? e : Variable(v2)
Expand All @@ -177,8 +186,13 @@ x = Variable(:x)
y = Variable(:y)
ex = Apply(λ(:x, λ(:y, Apply(x, y))), y)
g = EGraph{LambdaExpr,LambdaAnalysis}(ex)
saturate!(g, λT)
params = SaturationParams(
timer = false,
check_analysis = true
0x0f0f0f marked this conversation as resolved.
Show resolved Hide resolved
)
saturate!(g, λT, params)
@test λ(:a₄, Apply(y, Variable(:a₄))) == extract!(g, astsize)
@test Set([:y]) == g[g.root].data


# With the above we can implement, for example, Church numerals.
Expand All @@ -200,12 +214,35 @@ params = SaturationParams(
scheduler = Schedulers.BackoffScheduler,
schedulerparams = (match_limit = 6000, ban_length = 5),
timer = false,
check_analysis = true
)
saturate!(g, λT, params)
two_ = extract!(g, astsize)
@test two_ == λ(:a₁, λ(:a₇, Apply(Variable(:a₁), Apply(Variable(:a₁), Variable(:a₇)))))
@test two_ == λ(:x, λ(:y, Apply(Variable(:x), Apply(Variable(:x), Variable(:y)))))
@test g[g.root].data == Set([])
two_

# which is the same as `two` up to $\alpha$-conversion:

two

# check semantic analysis for free variables
function test_free_variable_analysis(expr, free)
g = EGraph{LambdaExpr,LambdaAnalysis}(expr)
g[g.root].data == free
end

@test test_free_variable_analysis(Variable(:x), Set([:x]))
@test test_free_variable_analysis(Apply(Variable(:x), Variable(:y)), Set([:x, :y]))
@test test_free_variable_analysis(λ(:z, Variable(:x)), Set([:x]))
@test test_free_variable_analysis(λ(:z, Variable(:z)), Set{Symbol}())
@test test_free_variable_analysis(λ(:z, λ(:x, Variable(:x))), Set{Symbol}())

let_expr = Let(:x, Variable(:z), λ(:x, Variable(:y)))
@test test_free_variable_analysis(let_expr, Set([:z, :y]))
# after saturation the expression becomes λ(:x, Variable(:y)) where only :y is left as free variable
freshvar = fresh_var_generator()
g = EGraph{LambdaExpr,LambdaAnalysis}(let_expr)
saturate!(g, λT, params)
@test extract!(g, astsize) == λ(:x, Variable(:y))
@test g[g.root].data == Set([:y])