Skip to content

Commit

Permalink
add egraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
a committed May 2, 2024
1 parent c4c78c8 commit ef35bf3
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 3 deletions.
88 changes: 88 additions & 0 deletions src/egraphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using Metatheory, SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, unflatten, toterm
using Metatheory: @rule

using SymbolicUtils: <

function EGraphs.preprocess(t::BasicSymbolic)
toterm(unflatten(t))
end


"""
Equational rewrite rules for optimizing expressions
"""
opt_theory = @theory a b x y begin
a + b == b + a
a * b == b * a
a * x + a * y == a * (x + y)
-1 * a == -a
a + (-1 * b) == a - b
x^-1 == 1 / x
1 / x * a == a / x
# fraction rules
# (a/b) + (c/b) => (a+c)/b
# trig functions
sin(x) / cos(x) == tan(x)
cos(x) / sin(x) == cot(x)
sin(x)^2 + cos(x)^2 --> 1
sin(2a) == 2sin(a)cos(a)
end


"""
Approximation of costs of operators in number
of CPU cycles required for the numerical computation
See
* https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/
* https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/
* https://github.com/triscale-innov/GFlops.jl
"""
const op_costs = Dict(
(+) => 1,
(-) => 1,
abs => 2,
(*) => 3,
exp => 18,
(/) => 24,
(^) => 100,
log1p => 124,
deg2rad => 125,
rad2deg => 125,
acos => 127,
asind => 128,
acsch => 133,
sin => 134,
cos => 134,
atan => 135,
tan => 156,
)
# TODO some operator costs are in FLOP and not in cycles!!

function costfun(n::VecExpr, op, children_costs::Vector{Float64})::Float64
v_isexpr(n) || return 0 #1
get(op_costs, op, 1) + sum(children_costs)
end


function optimize(ex; params=SaturationParams(timeout=20))
# @show ex
g = EGraph{BasicSymbolic}(ex)
saturate!(g, opt_theory, params)
return extract!(g, costfun)
end



# =======================================================================

@syms x y z a b c

expr = ((a * b) + (a * c)) / ((x * y) + (x * z) )

g = EGraph{BasicSymbolic}(expr)
saturate!(g, opt_theory, SaturationParams())
return extract!(g, costfun)

@benchmark optimize(expr)
6 changes: 3 additions & 3 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ end

unflatten(t) = t

function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata)
basicsymbolic(first(args), args[2:end], type, metadata)
function TermInterface.maketerm(::Type{<:BasicSymbolic}, op, args, type=nothing, metadata=nothing)
basicsymbolic(op, args, type, metadata)
end


Expand Down Expand Up @@ -649,7 +649,7 @@ function similarterm(x, op, args, symtype=nothing; metadata=nothing)
The present call can be replaced by
`maketerm(typeof(x), $(head(x)), [op, args...], symtype, metadata)`""", :similarterm)

TermInterface.maketerm(typeof(x), callhead(x), [op, args...], symtype, metadata)
TermInterface.maketerm(typeof(x), op, args, symtype, metadata)
end

# Old fallback
Expand Down

0 comments on commit ef35bf3

Please sign in to comment.