Skip to content

Commit

Permalink
adjust e-graph integration
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Cheli committed Jun 27, 2024
1 parent e338734 commit 0c2a450
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 55 deletions.
109 changes: 54 additions & 55 deletions test/egraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,40 +61,33 @@ opt_theory = @theory a b c x y z begin
sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y)
# hyperbolic trigonometric
# are these optimizing at all? dont think so
# sinh(x) == (ℯ^x - ℯ^(-x))/2
# csch(x) == 1/sinh(x)
# cosh(x) == (ℯ^x + ℯ^(-x))/2
# sech(x) == 1/cosh(x)
# sech(x) == 2/(ℯ^x + ℯ^(-x))
# tanh(x) == sinh(x)/cosh(x)
# tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x))
# coth(x) == 1/tanh(x)
# coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x))

# cosh(x)^2 - sinh(x)^2 --> 1
# tanh(x)^2 + sech(x)^2 --> 1
# coth(x)^2 - csch(x)^2 --> 1

# asinh(z) == log(z + √(z^2 + 1))
# acosh(z) == log(z + √(z^2 - 1))
# atanh(z) == log((1+z)/(1-z))/2
# acsch(z) == log((1+√(1+z^2)) / z )
# asech(z) == log((1 + √(1-z^2)) / z )
# acoth(z) == log( (z+1)/(z-1) )/2
sinh(x) == (ℯ^x -^(-x))/2
csch(x) == 1/sinh(x)
cosh(x) == (ℯ^x +^(-x))/2
sech(x) == 1/cosh(x)
sech(x) == 2/(ℯ^x +^(-x))
tanh(x) == sinh(x)/cosh(x)
tanh(x) == (ℯ^x -^(-x))/(ℯ^x +^(-x))
coth(x) == 1/tanh(x)
coth(x) == (ℯ^x +^-x)/(ℯ^x -^(-x))

cosh(x)^2 - sinh(x)^2 --> 1
tanh(x)^2 + sech(x)^2 --> 1
coth(x)^2 - csch(x)^2 --> 1

asinh(z) == log(z + (z^2 + 1))
acosh(z) == log(z + (z^2 - 1))
atanh(z) == log((1+z)/(1-z))/2
acsch(z) == log((1+√(1+z^2)) / z )
asech(z) == log((1 + (1-z^2)) / z )
acoth(z) == log( (z+1)/(z-1) )/2

# folding
x::Number * y::Number => x*y
x::Number + y::Number => x+y
x::Number / y::Number => x/y
x::Number - y::Number => x-y
end
# opt_theory = @theory a b c x y begin
# a * x == x * a
# a * x + a * y == a*(x+y)
# -1 * a == -a
# a + -b --> a - b
# -b + a --> b - a
# end


# See
Expand All @@ -103,41 +96,46 @@ end
# * https://github.com/triscale-innov/GFlops.jl
# Measure the cost of expressions in terms of number of ASM instructions

const op_costs = Dict()

const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)]

const io = IOBuffer()

for f in vcat(monadic, [-])
z = get!(op_costs, nameof(f), Dict())
for (t, at) in types
try
InteractiveUtils.code_native(io, f, (t,))
catch e
z[(t,)] = z[(at,)] = 1
continue
function make_op_costs()
const op_costs = Dict()

const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)]

const io = IOBuffer()

for f in vcat(monadic, [-])
z = get!(op_costs, nameof(f), Dict())
for (t, at) in types
try
InteractiveUtils.code_native(io, f, (t,))
catch e
z[(t,)] = z[(at,)] = 1
continue
end
str = String(take!(io))
z[(t,)] = z[(at,)] = length(split(str, "\n"))
end
str = String(take!(io))
z[(t,)] = z[(at,)] = length(split(str, "\n"))
end
end

for f in vcat(diadic, [+, -, *, /, //, ^])
z = get!(op_costs, nameof(f), Dict())
for (t1, at1) in types, (t2, at2) in types
try
InteractiveUtils.code_native(io, f, (t1, t2))
catch e
z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1
continue

for f in vcat(diadic, [+, -, *, /, //, ^])
z = get!(op_costs, nameof(f), Dict())
for (t1, at1) in types, (t2, at2) in types
try
InteractiveUtils.code_native(io, f, (t1, t2))
catch e
z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1
continue
end
str = String(take!(io))
z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n"))
end
str = String(take!(io))
z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n"))
end

op_costs
end

function getopcost(f::Function, types::Tuple)
function getopcost(op_costs, f::Function, types::Tuple)
sym = nameof(f)
if haskey(op_costs, sym) && haskey(op_costs[sym], types)
return op_costs[sym][types]
Expand Down Expand Up @@ -175,6 +173,7 @@ denoisescalars(x, atol=1e-11) = Postwalk(Chain([
@acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol)
]))(x)

const op_costs = make_op_costs()
function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...)
# ex = simplify(denoisescalars(ex, atol))
# println(ex)
Expand Down
2 changes: 2 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
@syms a b c

using Metatheory

@testset "Equality" begin
@eqtest a == a
@eqtest a != b
Expand Down

0 comments on commit 0c2a450

Please sign in to comment.