Skip to content

Commit

Permalink
Define get_dict function
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenszhu committed Sep 18, 2024
1 parent 206d21c commit 53ca3b8
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/inspect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
string(x.impl.val)
elseif isadd(x)
string(exprtype(x),
(scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict)))
(scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x))))
elseif ismul(x)
string(exprtype(x),
(scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict)))
(scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in get_dict(x))))
elseif isdiv(x) || ispow(x)
string(exprtype(x))
else
Expand Down
8 changes: 4 additions & 4 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,12 @@ end
# mul, pow case
function quick_mulpow(x, y)
y.impl.exp isa Number || return (x, y)
if haskey(x.impl.dict, y.impl.base)
d = copy(x.impl.dict)
if x.impl.dict[y.impl.base] > y.impl.exp
if haskey(get_dict(x), y.impl.base)
d = copy(get_dict(x))
if get_dict(x)[y.impl.base] > y.impl.exp
d[y.impl.base] -= y.impl.exp
den = 1
elseif x.impl.dict[y.impl.base] == y.impl.exp
elseif get_dict(x)[y.impl.base] == y.impl.exp
delete!(d, y.impl.base)
den = 1
else
Expand Down
38 changes: 21 additions & 17 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ function get_coeff(x::BasicSymbolic)
x.impl.coeff
end

function get_dict(x::BasicSymbolic)
x.impl.dict
end

# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
Expand Down Expand Up @@ -297,7 +301,7 @@ function _isequal(a, b, E)
if E === SYM
nameof(a) === nameof(b)
elseif E === ADD || E === MUL
coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(a.impl.dict, b.impl.dict)
coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b))
elseif E === DIV
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
elseif E === POW
Expand Down Expand Up @@ -341,7 +345,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
h = s.hash[]
!iszero(h) && return h
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
h′ = hash(hashoffset, hash(get_coeff(s), hash(s.impl.dict, salt)))
h′ = hash(hashoffset, hash(get_coeff(s), hash(get_dict(s), salt)))
s.hash[] = h′
return h′
elseif E === DIV
Expand Down Expand Up @@ -461,7 +465,7 @@ function maybe_intcoeff(x)
if ismul(x)
coeff = get_coeff(x)
if coeff isa Rational && isone(denominator(coeff))
_Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata)
_Mul(symtype(x), coeff.num, get_dict(x); metadata = x.metadata)
else
x
end
Expand Down Expand Up @@ -542,7 +546,7 @@ function toterm(t::BasicSymbolic{T}) where {T}
elseif E === ADD || E === MUL
args = BasicSymbolic[]
push!(args, get_coeff(t))
for (k, coeff) in t.impl.dict
for (k, coeff) in get_dict(t)
push!(
args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k]))
end
Expand All @@ -567,15 +571,15 @@ function makeadd(sign, coeff, xs...)
for x in xs
if isadd(x)
coeff += get_coeff(x)
_merge!(+, d, x.impl.dict, filter = _iszero)
_merge!(+, d, get_dict(x), filter = _iszero)
continue
end
if x isa Number
coeff += x
continue
end
if ismul(x)
k = _Mul(symtype(x), 1, x.impl.dict)
k = _Mul(symtype(x), 1, get_dict(x))
v = sign * get_coeff(x) + get(d, k, 0)
else
k = x
Expand All @@ -598,7 +602,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}())
coeff *= x
elseif ismul(x)
coeff *= get_coeff(x)
_merge!(+, d, x.impl.dict, filter = _iszero)
_merge!(+, d, get_dict(x), filter = _iszero)
else
v = 1 + get(d, x, 0)
if _iszero(v)
Expand Down Expand Up @@ -1223,10 +1227,10 @@ function +(a::SN, b::SN)
!issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata
if isadd(a) && isadd(b)
return _Add(
add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero))
elseif isadd(a)
coeff, dict = makeadd(1, 0, b)
return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero))
return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, get_dict(a), dict, filter = _iszero))
elseif isadd(b)
return b + a
end
Expand All @@ -1240,7 +1244,7 @@ function +(a::Number, b::SN)
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
iszero(a) && return b
if isadd(b)
_Add(add_t(a, b), a + get_coeff(b), b.impl.dict)
_Add(add_t(a, b), a + get_coeff(b), get_dict(b))
else
_Add(add_t(a, b), makeadd(1, a, b)...)
end
Expand All @@ -1258,15 +1262,15 @@ function -(a::SN)
return term(-, a)
end
if isadd(a)
_Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, a.impl.dict))
_Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, get_dict(a)))
else
_Add(sub_t(a), makeadd(-1, 0, a)...)
end
end
function -(a::SN, b::SN)
(!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b)
if isadd(a) && isadd(b)
_Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero))
_Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, get_dict(a), get_dict(b), filter = _iszero))
else
a + (-b)
end
Expand Down Expand Up @@ -1294,16 +1298,16 @@ function *(a::SN, b::SN)
_Div(a * b.impl.num, b.impl.den)
elseif ismul(a) && ismul(b)
_Mul(mul_t(a, b), get_coeff(a) * get_coeff(b),
_merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
_merge(+, get_dict(a), get_dict(b), filter = _iszero))
elseif ismul(a) && ispow(b)
if b.impl.exp isa Number
_Mul(mul_t(a, b),
get_coeff(a),
_merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp),
_merge(+, get_dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp),
filter = _iszero))
else
_Mul(mul_t(a, b), get_coeff(a),
_merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero))
_merge(+, get_dict(a), Base.ImmutableDict(b => 1), filter = _iszero))
end
elseif ispow(a) && ismul(b)
b * a
Expand All @@ -1326,7 +1330,7 @@ function *(a::Number, b::SN)
# -1(a+b) -> -a - b
T = promote_symtype(+, typeof(a), symtype(b))
_Add(T, get_coeff(b) * a,
Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict))
Dict{BasicSymbolic, Any}(k => v * a for (k, v) in get_dict(b)))
else
_Mul(mul_t(a, b), makemul(a, b)...)
end
Expand All @@ -1352,7 +1356,7 @@ function ^(a::SN, b)
elseif ismul(a) && b isa Number
coeff = unstable_pow(get_coeff(a), b)
_Mul(promote_symtype(^, symtype(a), symtype(b)),
coeff, mapvalues((k, v) -> b * v, a.impl.dict))
coeff, mapvalues((k, v) -> b * v, get_dict(a)))
else
_Pow(a, b)
end
Expand Down
4 changes: 2 additions & 2 deletions test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm,
BasicSymbolic, term, get_name, get_coeff
BasicSymbolic, term, get_name, get_coeff, get_dict
using SymbolicUtils
using IfElse: ifelse
using Setfield
Expand Down Expand Up @@ -234,7 +234,7 @@ end

@testset "maketerm" begin
@syms a b c
@test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).impl.dict, Dict(a=>1,b=>1,c=>1))
@test isequal(get_dict(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing)), Dict(a=>1,b=>1,c=>1))
@test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b)

# test that maketerm doesn't hard-code BasicSymbolic subtype
Expand Down

0 comments on commit 53ca3b8

Please sign in to comment.