Skip to content

Commit

Permalink
building
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Cheli committed Jun 24, 2024
1 parent 7b72cf1 commit e9ebd8f
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 63 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
TermInterface = "1.0.1"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
julia = "1.3"
Expand Down
2 changes: 1 addition & 1 deletion src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using TermInterface
import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm

export operation, arguments, unsorted_arguments, iscall
export operation, arguments, sorted_arguments, iscall
# Sym, Term,
# Add, Mul and Pow
include("types.jl")
Expand Down
20 changes: 10 additions & 10 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
symtype, unsorted_arguments, metadata, isterm, term
symtype, sorted_arguments, metadata, isterm, term

##== state management ==##

Expand Down Expand Up @@ -115,7 +115,7 @@ function function_to_expr(op, O, st)
(get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing
name = nameof(op)
fun = GlobalRef(NaNMath, name)
args = map(Base.Fix2(toexpr, st), arguments(O))
args = map(Base.Fix2(toexpr, st), sorted_arguments(O))
expr = Expr(:call, fun)
append!(expr.args, args)
return expr
Expand All @@ -124,7 +124,7 @@ end
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
out = get(st.rewrites, O, nothing)
out === nothing || return out
args = map(Base.Fix2(toexpr, st), arguments(O))
args = map(Base.Fix2(toexpr, st), sorted_arguments(O))
if length(args) >= 3 && symtype(O) <: Number
x, xs = Iterators.peel(args)
foldl(xs, init=x) do a, b
Expand All @@ -138,7 +138,7 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
end

function function_to_expr(::typeof(^), O, st)
args = arguments(O)
args = sorted_arguments(O)
if length(args) == 2 && args[2] isa Real && args[2] < 0
ex = args[1]
if args[2] == -1
Expand All @@ -151,7 +151,7 @@ function function_to_expr(::typeof(^), O, st)
end

function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
args = arguments(O)
args = sorted_arguments(O)
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
end

Expand Down Expand Up @@ -183,7 +183,7 @@ function toexpr(O, st)
return expr′
else
!iscall(O) && return O
args = arguments(O)
args = sorted_arguments(O)
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
end
end
Expand Down Expand Up @@ -693,7 +693,7 @@ end
function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
args = map(Base.Fix1(_cse!, mem), arguments(expr))
args = map(Base.Fix1(_cse!, mem), sorted_arguments(expr))
t = maketerm(typeof(expr), op, args, nothing)

v, dict = mem
Expand All @@ -716,7 +716,7 @@ end

function _cse(exprs::AbstractArray)
letblock = cse(Term{Any}(tuple, vec(exprs)))
letblock.pairs, reshape(arguments(letblock.body), size(exprs))
letblock.pairs, reshape(sorted_arguments(letblock.body), size(exprs))
end

function cse(x::MakeArray)
Expand Down Expand Up @@ -744,7 +744,7 @@ end
function cse_state!(state, t)
!iscall(t) && return t
state[t] = Base.get(state, t, 0) + 1
foreach(x->cse_state!(state, x), unsorted_arguments(t))
foreach(x->cse_state!(state, x), arguments(t))
end

function cse_block!(assignments, counter, names, name, state, x)
Expand All @@ -759,7 +759,7 @@ function cse_block!(assignments, counter, names, name, state, x)
return sym
end
elseif iscall(x)
args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x))
args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x))
if isterm(x)
return term(operation(x), args...)
else
Expand Down
2 changes: 1 addition & 1 deletion src/inspect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
end

function AbstractTrees.children(x::Symbolic)
iscall(x) ? arguments(x) : isexpr(x) ? children(x) : ()
iscall(x) ? sorted_arguments(x) : isexpr(x) ? children(x) : ()
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function matcher(segment::Segment)
end

function term_matcher(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
matchers = (matcher(operation(term)), map(matcher, sorted_arguments(term))...,)
function term_matcher(success, data, bindings)

!islist(data) && return nothing
Expand Down
8 changes: 4 additions & 4 deletions src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function get_degrees(expr)
((Symbol(expr),) => 1,)
elseif iscall(expr)
op = operation(expr)
args = arguments(expr)
args = sorted_arguments(expr)
if operation(expr) == (^) && args[2] isa Number
return map(get_degrees(args[1])) do (base, pow)
(base => pow * args[2])
Expand All @@ -35,7 +35,7 @@ function get_degrees(expr)
_, idx = findmax(x->sum(last.(x), init=0), ds)
return ds[idx]
elseif operation(expr) == (getindex)
args = arguments(expr)
args = sorted_arguments(expr)
return ((Symbol.(args)...,) => 1,)
else
return ((Symbol("zzzzzzz", hash(expr)),) => 1,)
Expand All @@ -62,7 +62,7 @@ function lexlt(degs1, degs2)
return false # they are equal
end

_arglen(a) = iscall(a) ? length(unsorted_arguments(a)) : 0
_arglen(a) = iscall(a) ? length(arguments(a)) : 0

function <(a::Tuple, b::Tuple)
for (x, y) in zip(a, b)
Expand All @@ -81,7 +81,7 @@ function <ₑ(a::BasicSymbolic, b::BasicSymbolic)
bw = monomial_lt(db, da)
if fw === bw && !isequal(a, b)
if _arglen(a) == _arglen(b)
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
return (operation(a), sorted_arguments(a)...,) <ₑ (operation(b), sorted_arguments(b)...,)
else
return _arglen(a) < _arglen(b)
end
Expand Down
16 changes: 8 additions & 8 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
end

op = operation(x)
args = arguments(x)
args = sorted_arguments(x)

local_polyize(y) = polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse)

Expand Down Expand Up @@ -343,7 +343,7 @@ end

function add_with_div(x, flatten=true)
(!iscall(x) || operation(x) != (+)) && return x
aa = unsorted_arguments(x)
aa = arguments(x)
!any(a->isdiv(a), aa) && return x # no rewrite necessary

divs = filter(a->isdiv(a), aa)
Expand Down Expand Up @@ -381,16 +381,16 @@ end

function needs_div_rules(x)
(isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) ||
(iscall(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) ||
(iscall(x) && any(needs_div_rules, unsorted_arguments(x)))
(iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) ||
(iscall(x) && any(needs_div_rules, arguments(x)))
end

function has_div(x)
return isdiv(x) || (iscall(x) && any(has_div, unsorted_arguments(x)))
return isdiv(x) || (iscall(x) && any(has_div, arguments(x)))
end

flatten_pows(xs) = map(xs) do x
ispow(x) ? Iterators.repeated(arguments(x)...) : (x,)
ispow(x) ? Iterators.repeated(sorted_arguments(x)...) : (x,)
end |> Iterators.flatten |> a->collect(Any,a)

coefftype(x::PolyForm) = coefftype(x.p)
Expand All @@ -414,8 +414,8 @@ Has optimized processes for `Mul` and `Pow` terms.
function quick_cancel(d)
if ispow(d) && isdiv(d.base)
return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp))
elseif ismul(d) && any(isdiv, unsorted_arguments(d))
return prod(unsorted_arguments(d))
elseif ismul(d) && any(isdiv, arguments(d))
return prod(arguments(d))
elseif isdiv(d)
num, den = quick_cancel(d.num, d.den)
return Div(num, den)
Expand Down
8 changes: 4 additions & 4 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module Rewriters
using SymbolicUtils: @timer
using TermInterface

import SymbolicUtils: iscall, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype
import SymbolicUtils: iscall, operation, arguments, sorted_arguments, metadata, node_count, _promote_symtype
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough

# Cache of printed rules to speed up @timer
Expand Down Expand Up @@ -205,7 +205,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}

if iscall(x)
x = p.maketerm(typeof(x), operation(x), map(PassThrough(p),
unsorted_arguments(x)), metadata(x))
arguments(x)), metadata(x))
end

return ord === :post ? p.rw(x) : x
Expand All @@ -221,14 +221,14 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
x = p.rw(x)
end
if iscall(x)
_args = map(arguments(x)) do arg
_args = map(sorted_arguments(x)) do arg
if node_count(arg) > p.thread_cutoff
Threads.@spawn p(arg)
else
p(arg)
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, sorted_arguments(x))
t = p.maketerm(typeof(x), operation(x), args, metadata(x))
end
return ord === :post ? p.rw(t) : t
Expand Down
4 changes: 2 additions & 2 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ getdepth(r::Rule) = r.depth

function rule_depth(rule, d=0, maxdepth=0)
if iscall(rule)
maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in arguments(rule)), init=1)
maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in sorted_arguments(rule)), init=1)
elseif rule isa Slot || rule isa Segment
maxdepth = max(d, maxdepth)
end
Expand Down Expand Up @@ -399,7 +399,7 @@ function (acr::ACRule)(term)
end

T = symtype(term)
args = unsorted_arguments(term)
args = arguments(term)

itr = acr.sets(eachindex(args), acr.arity)

Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ end

has_operation(x, op) = (iscall(x) && (operation(x) == op ||
any(a->has_operation(a, op),
unsorted_arguments(x))))
arguments(x))))

Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)
4 changes: 2 additions & 2 deletions src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ function substitute(expr, dict; fold=true)
op = substitute(operation(expr), dict; fold=fold)
if fold
canfold = !(op isa Symbolic)
args = map(unsorted_arguments(expr)) do x
args = map(arguments(expr)) do x
x′ = substitute(x, dict; fold=fold)
canfold = canfold && !(x′ isa Symbolic)
x′
end
canfold && return op(args...)
args
else
args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr))
args = map(x->substitute(x, dict, fold=fold), arguments(expr))
end

maketerm(typeof(expr),
Expand Down
23 changes: 11 additions & 12 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ end

@inline head(x::BasicSymbolic) = operation(x)

function arguments(x::BasicSymbolic)
args = unsorted_arguments(x)
function sorted_arguments(x::BasicSymbolic)
args = arguments(x)
@compactified x::BasicSymbolic begin
Add => @goto ADD
Mul => @goto MUL
Expand All @@ -148,9 +148,8 @@ function arguments(x::BasicSymbolic)
return args
end

unsorted_arguments(x) = arguments(x)
children(x::BasicSymbolic) = arguments(x)
function unsorted_arguments(x::BasicSymbolic)
function arguments(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Term => return x.arguments
Add => @goto ADDMUL
Expand Down Expand Up @@ -254,8 +253,8 @@ function _isequal(a, b, E)
elseif E === POW
isequal(a.exp, b.exp) && isequal(a.base, b.base)
elseif E === TERM
a1 = arguments(a)
a2 = arguments(b)
a1 = sorted_arguments(a)
a2 = sorted_arguments(b)
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
else
error_on_type()
Expand Down Expand Up @@ -296,7 +295,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
!iszero(h) && return h
op = operation(s)
oph = op isa Function ? nameof(op) : op
h′ = hashvec(arguments(s), hash(oph, salt))
h′ = hashvec(sorted_arguments(s), hash(oph, salt))
s.hash[] = h′
return h′
else
Expand Down Expand Up @@ -426,7 +425,7 @@ end

@inline function numerators(x)
isdiv(x) && return numerators(x.num)
iscall(x) && operation(x) === (*) ? arguments(x) : Any[x]
iscall(x) && operation(x) === (*) ? sorted_arguments(x) : Any[x]
end

@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]
Expand Down Expand Up @@ -545,7 +544,7 @@ function unflatten(t::Symbolic{T}) where{T}
if iscall(t)
f = operation(t)
if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops
a = arguments(t)
a = sorted_arguments(t)
return foldl((x,y) -> Term{T}(f, Any[x, y]), a)
end
end
Expand Down Expand Up @@ -662,7 +661,7 @@ const show_simplified = Ref(false)
isnegative(t::Real) = t < 0
function isnegative(t)
if iscall(t) && operation(t) === (*)
coeff = first(arguments(t))
coeff = first(sorted_arguments(t))
return isnegative(coeff)
end
return false
Expand Down Expand Up @@ -694,7 +693,7 @@ end
function remove_minus(t)
!iscall(t) && return -t
@assert operation(t) == (*)
args = arguments(t)
args = sorted_arguments(t)
@assert args[1] < 0
Any[-args[1], args[2:end]...]
end
Expand Down Expand Up @@ -806,7 +805,7 @@ function show_term(io::IO, t)
end

f = operation(t)
args = arguments(t)
args = sorted_arguments(t)
if symtype(t) <: LiteralReal
show_call(io, f, args)
elseif f === (+)
Expand Down
Loading

0 comments on commit e9ebd8f

Please sign in to comment.