Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

constructor-level simplification #154

Merged
merged 29 commits into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
313bc7f
begin constructor-level simplification
shashi Jan 4, 2021
d1d7da3
fixes
shashi Jan 4, 2021
c99d25a
div
shashi Jan 4, 2021
e23eba9
whitespace
shashi Jan 4, 2021
a8ee55b
Add `Number - SN` overload and eagerly evaluate coeff when adding a n…
YingboMa Jan 4, 2021
febf9fe
enable tree interface on the fast terms
shashi Jan 5, 2021
25e735f
Fix stuff
shashi Jan 5, 2021
b7928ec
Bring back better printing
shashi Jan 5, 2021
fe5a5f6
fixes 3
shashi Jan 5, 2021
f3b67b5
Merge remote-tracking branch 'origin/s/fast-terms' into s/fast-terms
shashi Jan 5, 2021
e22db85
fix tests and fix printing
shashi Jan 5, 2021
7f7e04f
Delete some more printing code
shashi Jan 5, 2021
f442971
fuzz: print problem
shashi Jan 5, 2021
acdf077
fix printing with Rational and Complex
shashi Jan 5, 2021
6ed0e9d
Cache sorted arguments in Add and Mul
shashi Jan 6, 2021
0f89ab1
fix arguments on Mul copy-paste
shashi Jan 6, 2021
7278252
fix (a+b)-a
shashi Jan 6, 2021
2ec353e
updates for MTK
shashi Jan 6, 2021
f621220
add 1-arg *
shashi Jan 6, 2021
20c63c4
show function of Term so that Differential(t) is visible
shashi Jan 6, 2021
acaf8e3
Fix overload ambiguity
YingboMa Jan 6, 2021
35a0bfa
print fix
shashi Jan 6, 2021
dafe2e6
Fix and test `-(::Add)`
YingboMa Jan 6, 2021
e5cb78d
configurable similarterm in Walk
shashi Jan 8, 2021
4247bd1
proper type promotion for Add
shashi Jan 9, 2021
277d6ea
proper type promotion for Mul and Pow
shashi Jan 9, 2021
e2a8e4c
fix up promotion and tests
shashi Jan 9, 2021
e2517f5
Move Add Mul Pow into types.jl
shashi Jan 9, 2021
c62beb7
add some docs
shashi Jan 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ export @syms, term, showraw
# Sym, Term and other types
include("types.jl")

# Add, Mul and Pow
include("fast-terms.jl")

# Methods on symbolic objects
using SpecialFunctions, NaNMath
export cond
Expand Down
278 changes: 278 additions & 0 deletions src/fast-terms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
import Base: +, -, *, /, ^

const SN = Symbolic{<:Number}
"""
Add(coeff, dict)

Represents coeff + (key1 * val1) + (key2 * val2) + ...

where coeff and the vals are non-symbolic numbers.

"""
struct Add{X, T, D} <: Symbolic{X}
coeff::T
dict::D
end

function Add(coeff, dict)
if isempty(dict)
return coeff
end
Add{Number, typeof(coeff), typeof(dict)}(coeff,dict)
end

symtype(a::Add{X}) where {X} = X

istree(a::Add) = true

operation(a::Add) = +

arguments(a::Add) = vcat(a.coeff, [v*k for (k,v) in a.dict])

Base.hash(a::Add, u::UInt64) = hash(a.coeff, hash(a.dict, u))

Base.isequal(a::Add, b::Add) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)

function Base.show(io::IO, a::Add)
print_coeff = !iszero(a.coeff)
print_coeff && print(io, a.coeff)

for (i, (k, v)) in enumerate(a.dict)
if (i == 1 && print_coeff) || i != 1
print(io, " + ")
end
if isone(v)
print(io, k)
else
print(io, v, k)
end
end
end

"""
make_add_dict(sign, xs...)

Any Muls inside an Add should always have a coeff of 1
and the key (in Add) should instead be used to store the actual coefficient
"""
function make_add_dict(sign, xs...)
d = Dict{Any, Number}()
for x in xs
if x isa Mul
k = Mul(1, x.dict)
v = sign * x.coeff + get(d, k, 0)
else
k = x
v = sign + get(d, x, 0)
end
if iszero(v)
delete!(d, k)
else
d[k] = v
end
end
d
end

+(a::Number, b::SN) = Add(a, make_add_dict(1, b))

+(a::SN, b::Number) = Add(b, make_add_dict(1, a))

function +(a::SN, b::SN)
if a isa Add
return a + Add(0, make_add_dict(1, b))
elseif b isa Add
return b + a
end
Add(0, make_add_dict(1, a, b))
end

+(a::Add, b::Add) = Add(a.coeff + b.coeff, _merge(+, a.dict, b.dict, filter=iszero))

+(a::Number, b::Add) = iszero(a) ? b : Add(a, make_add_dict(1, b))

+(b::Add, a::Number) = iszero(a) ? b : Add(a, make_add_dict(1, b))
YingboMa marked this conversation as resolved.
Show resolved Hide resolved

-(a::Add) = Add(-a.coeff, mapvalues(-, a.dict))

-(a::SN) = Add(0, make_add_dict(-1, a))

-(a::Add, b::Add) = Add(a.coeff - b.coeff, _merge(-, a.dict, b.dict, filter=iszero))

-(a::SN, b::SN) = a + (-b)

"""
Mul(coeff, dict)

Represents coeff * (key1 ^ val1) * (key2 ^ val2) * ....

where coeff is a non-symbolic number.
"""
struct Mul{X, T, D} <: Symbolic{X}
coeff::T
dict::D
end

function Mul(a,b)
isempty(b) && return a
if isone(a) && length(b) == 1
pair = first(b)
if isone(last(pair)) # first value
return first(pair)
else
return Pow(first(pair), last(pair))
end
else
Mul{Number, typeof(a), typeof(b)}(a,b)
end
end

symtype(a::Mul{X}) where {X} = X

istree(a::Mul) = true

operation(a::Mul) = *

arguments(a::Mul) = vcat(a.coeff, [k^v for (k,v) in a.dict])

Base.hash(m::Mul, u::UInt64) = hash(m.coeff, hash(m.dict, u))

Base.isequal(a::Mul, b::Mul) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)

function Base.show(io::IO, a::Mul)
print_coeff = !isone(a.coeff)
print_coeff && print(io, a.coeff)

for (i, (k, v)) in enumerate(a.dict)
if (i == 1 && print_coeff) || i != 1
print(io, " * ")
end
if isone(v)
if !(k isa Sym)
print(io, "(", k, ")")
else
print(io, k)
end
else
if !(k isa Sym)
print(io, "(", k, ")^", v)
else
print(io, k, "^", v)
end
end
end
end

"""
make_mul_dict(xs...)
"""
function make_mul_coeff_dict(sign, coeff, xs...; d=Dict{Any, Number}())
for x in xs
if x isa Pow
d[x.base] = sign * x.exp + get(d, x.base, 0)
elseif x isa Mul
coeff *= x.coeff
dict = isone(sign) ? x.dict : mapvalues((_,v)->sign*v, x.dict)
d = _merge(+, d, dict, filter=iszero)
else
k = x
v = sign + get(d, x, 0)
if iszero(v)
delete!(d, k)
else
d[k] = v
end
end
end
coeff, d
end

*(a::SN, b::SN) = Mul(make_mul_coeff_dict(1, 1, a, b)...)

*(a::Mul, b::Mul) = Mul(a.coeff * b.coeff, _merge(+, a.dict, b.dict, filter=iszero))

*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : Mul(make_mul_coeff_dict(1,a, b)...)

*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : Mul(make_mul_coeff_dict(1,a, b)...)

function /(a::Union{SN,Number}, b::SN)
a * Mul(make_mul_coeff_dict(-1, 1, b)...)
end

\(a::SN, b::SN) = b / a

/(a::SN, b::Number) = inv(b) * a

"""
Pow(base, exp)

Represents base^exp, a lighter version of Mul(1, Dict(base=>exp))
"""
struct Pow{X, B, E} <: Symbolic{X}
base::B
exp::E
end

function Pow(a,b)
iszero(b) && return 1
isone(b) && return a
Pow{Number, typeof(a), typeof(b)}(a,b)
end

symtype(a::Pow{X}) where {X} = X

istree(a::Pow) = true

operation(a::Pow) = ^

arguments(a::Pow) = [a.base, a.exp]

Base.hash(p::Pow, u::UInt) = hash(p.exp, hash(p.base, u))

Base.isequal(p::Pow, b::Pow) = isequal(p.base, b.base) && isequal(p.exp, b.exp)

function Base.show(io::IO, p::Pow)
k, v = p.base, p.exp
if !(k isa Sym)
print(io, "(", k, ")^", v)
else
print(io, k, "^", v)
end
end

^(a::SN, b) = Pow(a, b)

function ^(a::Mul, b::Number)
Mul(a.coeff ^ b, mapvalues((k, v) -> b*v, a.dict))
end

function *(a::Mul, b::Pow)
Mul(a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=iszero))
end

*(a::Pow, b::Mul) = b * a

function _merge(f, d, others...; filter=x->false)
acc = copy(d)
for other in others
for (k, v) in other
if haskey(acc, k)
v = f(acc[k], v)
end
if filter(v)
delete!(acc, k)
else
acc[k] = v
end
end
end
acc
end

function mapvalues(f, d1::Dict)
d = copy(d1)
for (k, v) in d
d[k] = f(k, v)
end
d
end
23 changes: 2 additions & 21 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
besselj1, bessely0, bessely1, besselj, bessely, besseli,
besselk, hankelh1, hankelh2, polygamma, beta, logbeta

const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
const monadic = [deg2rad, rad2deg, transpose, conj, asind, log1p, acsch,
acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh,
abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, acotd,
asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10,
Expand All @@ -14,7 +14,7 @@ const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch,
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
airybiprime, besselj0, besselj1, bessely0, bessely1]

const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign,
const diadic = [max, min, hypot, atan, mod, rem, copysign,
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
polygamma, beta, logbeta]

Expand Down Expand Up @@ -93,25 +93,6 @@ rec_promote_symtype(f, x) = promote_symtype(f, x)
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)

# Variadic methods
for f in [+, *]

@eval (::$(typeof(f)))(x::Symbolic) = x

# single arg
@eval function (::$(typeof(f)))(x::Symbolic, w::Number...)
term($f, x,w...,
type=rec_promote_symtype($f, map(symtype, (x,w...))...))
end
@eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...)
term($f, x, y, w...,
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
end
end

Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
Expand Down