forked from SciML/ModelingToolkit.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpolyform.jl
81 lines (70 loc) · 1.64 KB
/
polyform.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
using FunctionalCollections
# Short alias for FunctionalCollections' `PersistentHashMap`; used both for the
# coefficient table of a `LinearCombination` and for each monomial key inside it.
const pdict = PersistentHashMap
# The empty monomial (no variables): `constterm` uses it as the key of the constant term.
# NOTE(review): its key/value type parameters are `Union{}`, so it presumably cannot
# have entries `assoc`'d into it, and inserting it into a table whose keys are e.g.
# `pdict{Symbol,Int}` may require a conversion — confirm that arithmetic mixing
# constant terms with real monomials works before relying on it.
const onekey = pdict{Union{}, Union{}}()
"""
    LinearCombination{T<:pdict}

A polynomial stored in `terms` as a `pdict` (`PersistentHashMap`) mapping each
monomial to its coefficient. Each monomial is itself a `pdict` mapping a variable
to its exponent; the constant term is keyed by the empty monomial `onekey`.

For example, `1//2 * x^2 * y^3 + 4//5 * x` is represented as

    pdict(pdict(x => 2, y => 3) => 1//2, pdict(x => 1) => 4//5)
"""
struct LinearCombination{T<:pdict}
    terms::T
end
"""
    _merge(f, d, others...)

Fold every key/value pair of each map in `others` into `d`, left to right.
A key not yet present is inserted with its value; a key already present gets
`f(existing, new)`. Each insertion goes through `assoc` and the accumulated map
is rebound after every step; the final accumulated map is returned (with no
`others`, that is `d` itself).
"""
function _merge(f, d, others...)
    merged = d
    for m in others, (key, val) in m
        merged = assoc(merged, key, haskey(merged, key) ? f(merged[key], val) : val)
    end
    return merged
end
"""
    constterm(b)

Wrap the scalar `b` as a single-term `LinearCombination` whose only monomial is
`onekey` (the empty monomial), i.e. the constant `b`.

NOTE(review): `onekey` is parameterised by `Union{}`; confirm it can be merged
with tables keyed by concretely typed monomials.
"""
function constterm(b)
    unit_term = onekey => b
    return LinearCombination(pdict(unit_term))
end
# Sum of two combinations: the term tables are merged, adding the coefficients of
# monomials present in both. Terms whose coefficients cancel to zero are kept.
function Base.:+(a::LinearCombination, b::LinearCombination)
    # Re-wrap the merged table. Returning the bare `pdict` (as before) made the
    # result unusable with every other `LinearCombination` method, e.g. the
    # `sum` of partial products in `*`.
    return LinearCombination(_merge(+, a.terms, b.terms))
end
# `combination + scalar`: adding zero returns `a` unchanged; any other scalar is
# added as the constant (empty-monomial) term via `constterm`.
function Base.:+(a::LinearCombination, b)
    iszero(b) && return a
    return a + constterm(b)
end
# `scalar + combination`: addition is treated as commutative, so this simply
# swaps the operands and reuses the `combination + scalar` method.
function Base.:+(a, b::LinearCombination)
    return b + a
end
"""
    mul_term(lhs::Pair, rhs::Pair)

Multiply two `monomial => coefficient` terms. The monomials (maps from variable
to exponent) are combined with `_merge(+, ...)`, adding the exponents of shared
variables. The coefficients are multiplied with the left term's coefficient on
the left, so coefficients need not commute.

For example, 42*x^2*y^2 times 56*x^3*z, i.e.
`pdict(:x=>2, :y=>2) => 42` times `pdict(:x=>3, :z=>1) => 56`,
gives `pdict(:x=>5, :y=>2, :z=>1) => 2352`.
"""
function mul_term(lhs::Pair, rhs::Pair)
    monomial = _merge(+, first(lhs), first(rhs))
    coeff = last(lhs) * last(rhs)
    return monomial => coeff
end
# Product of two combinations: every term of `a` is multiplied by every term of
# `b` (see `mul_term`, which keeps `a`'s coefficient on the left), and the
# partial products are merged, adding the coefficients of equal monomials.
function Base.:(*)(a::LinearCombination, b::LinearCombination)
    # A combination with no terms is zero. Without these guards the set of
    # partial products would be empty and the reduction below would throw.
    isempty(a.terms) && return a
    isempty(b.terms) && return b
    partials = (pdict(mul_term(ta, tb)) for ta in a.terms, tb in b.terms)
    # Merge the raw term tables directly rather than summing one wrapped
    # `LinearCombination` per partial product; this also keeps the product
    # independent of how `+` wraps its result.
    return LinearCombination(reduce((acc, p) -> _merge(+, acc, p), partials))
end
# `combination * scalar`. Multiplying by zero yields the scalar `zero(b)` (not a
# combination); multiplying by one returns `a` unchanged. Otherwise every
# coefficient is scaled on the right by `b`.
function Base.:(*)(a::LinearCombination, b)
    if iszero(b)
        return zero(b)
    elseif isone(b)
        return a
    end
    # Scale the coefficients directly instead of building `constterm(b)` and
    # running the full term-by-term product: the monomials are unchanged, so
    # only the coefficients need updating. `v * b` keeps the scalar on the right.
    return LinearCombination(pdict([k => v * b for (k, v) in a.terms]))
end
# `scalar * combination`. Multiplication is not assumed to be commutative, so
# this is not forwarded to `combination * scalar`: the scalar stays on the left of
# every coefficient. Multiplying by zero yields the scalar `zero(a)`; multiplying
# by one returns `b` unchanged.
function Base.:(*)(a, b::LinearCombination)
    if iszero(a)
        return zero(a)
    elseif isone(a)
        return b
    end
    # Scale the coefficients directly. The previous `constterm(a) * b` had to
    # merge each real monomial into `onekey`, whose key/value types are
    # `Union{}`, so it could not hold them. `a * v` keeps the scalar on the left.
    return LinearCombination(pdict([k => a * v for (k, v) in b.terms]))
end