From 0c2a4505c61f34c2a18bf21089dc215115025c94 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Thu, 27 Jun 2024 09:47:12 +0200 Subject: [PATCH] adjust e-graph integration --- test/egraphs.jl | 109 ++++++++++++++++++++++++------------------------ test/rewrite.jl | 2 + 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/test/egraphs.jl b/test/egraphs.jl index a5801cdaa..d139d316c 100644 --- a/test/egraphs.jl +++ b/test/egraphs.jl @@ -61,26 +61,26 @@ opt_theory = @theory a b c x y z begin sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y) # hyperbolic trigonometric # are these optimizing at all? dont think so - # sinh(x) == (ℯ^x - ℯ^(-x))/2 - # csch(x) == 1/sinh(x) - # cosh(x) == (ℯ^x + ℯ^(-x))/2 - # sech(x) == 1/cosh(x) - # sech(x) == 2/(ℯ^x + ℯ^(-x)) - # tanh(x) == sinh(x)/cosh(x) - # tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x)) - # coth(x) == 1/tanh(x) - # coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x)) - - # cosh(x)^2 - sinh(x)^2 --> 1 - # tanh(x)^2 + sech(x)^2 --> 1 - # coth(x)^2 - csch(x)^2 --> 1 - - # asinh(z) == log(z + √(z^2 + 1)) - # acosh(z) == log(z + √(z^2 - 1)) - # atanh(z) == log((1+z)/(1-z))/2 - # acsch(z) == log((1+√(1+z^2)) / z ) - # asech(z) == log((1 + √(1-z^2)) / z ) - # acoth(z) == log( (z+1)/(z-1) )/2 + sinh(x) == (ℯ^x - ℯ^(-x))/2 + csch(x) == 1/sinh(x) + cosh(x) == (ℯ^x + ℯ^(-x))/2 + sech(x) == 1/cosh(x) + sech(x) == 2/(ℯ^x + ℯ^(-x)) + tanh(x) == sinh(x)/cosh(x) + tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x)) + coth(x) == 1/tanh(x) + coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x)) + + cosh(x)^2 - sinh(x)^2 --> 1 + tanh(x)^2 + sech(x)^2 --> 1 + coth(x)^2 - csch(x)^2 --> 1 + + asinh(z) == log(z + √(z^2 + 1)) + acosh(z) == log(z + √(z^2 - 1)) + atanh(z) == log((1+z)/(1-z))/2 + acsch(z) == log((1+√(1+z^2)) / z ) + asech(z) == log((1 + √(1-z^2)) / z ) + acoth(z) == log( (z+1)/(z-1) )/2 # folding x::Number * y::Number => x*y @@ -88,13 +88,6 @@ opt_theory = @theory a b c x y z begin x::Number / y::Number => x/y x::Number - y::Number => x-y end -# opt_theory = @theory a b c x y begin -# a * x == x * a -# a * x + a * y == a*(x+y) -# -1 * a == -a -# a + -b --> a - b -# -b + a --> b - a -# end # See @@ -103,41 +96,46 @@ end # * https://github.com/triscale-innov/GFlops.jl # Measure the cost of expressions in terms of number of ASM instructions -const op_costs = Dict() -const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)] - -const io = IOBuffer() - -for f in vcat(monadic, [-]) - z = get!(op_costs, nameof(f), Dict()) - for (t, at) in types - try - InteractiveUtils.code_native(io, f, (t,)) - catch e - z[(t,)] = z[(at,)] = 1 - continue +function make_op_costs() + const op_costs = Dict() + + const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)] + + const io = IOBuffer() + + for f in vcat(monadic, [-]) + z = get!(op_costs, nameof(f), Dict()) + for (t, at) in types + try + InteractiveUtils.code_native(io, f, (t,)) + catch e + z[(t,)] = z[(at,)] = 1 + continue + end + str = String(take!(io)) + z[(t,)] = z[(at,)] = length(split(str, "\n")) end - str = String(take!(io)) - z[(t,)] = z[(at,)] = length(split(str, "\n")) end -end - -for f in vcat(diadic, [+, -, *, /, //, ^]) - z = get!(op_costs, nameof(f), Dict()) - for (t1, at1) in types, (t2, at2) in types - try - InteractiveUtils.code_native(io, f, (t1, t2)) - catch e - z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1 - continue + + for f in vcat(diadic, [+, -, *, /, //, ^]) + z = get!(op_costs, nameof(f), Dict()) + for (t1, at1) in types, (t2, at2) in types + try + InteractiveUtils.code_native(io, f, (t1, t2)) + catch e + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1 + continue + end + str = String(take!(io)) + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n")) end - str = String(take!(io)) - z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n")) end + + op_costs end -function getopcost(f::Function, types::Tuple) +function getopcost(op_costs, f::Function, types::Tuple) sym = nameof(f) if haskey(op_costs, sym) && haskey(op_costs[sym], types) return op_costs[sym][types] @@ -175,6 +173,7 @@ denoisescalars(x, atol=1e-11) = Postwalk(Chain([ @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol) ]))(x) +const op_costs = make_op_costs() function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...) # ex = simplify(denoisescalars(ex, atol)) # println(ex) diff --git a/test/rewrite.jl b/test/rewrite.jl index 8f4304ace..ccc754141 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,5 +1,7 @@ @syms a b c +using Metatheory + @testset "Equality" begin @eqtest a == a @eqtest a != b