Skip to content

Commit

Permalink
Merge pull request #533 from chriselrod/optimizeisequal
Browse files Browse the repository at this point in the history
optimize `isequal`
  • Loading branch information
shashi authored Jul 17, 2023
2 parents e3eba85 + 5b4f2b2 commit 3dc99d4
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,25 @@ Base.isequal(::Symbolic, x) = false
Base.isequal(x, ::Symbolic) = false
Base.isequal(::Symbolic, ::Symbolic) = false
coeff_isequal(a, b) = isequal(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
function _allarequal(xs, ys)::Bool
N = length(xs)
length(ys) == N || return false
for n = 1:N
isequal(xs[n], ys[n]) || return false
end
return true
end

function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
a === b && return true

E = exprtype(a)
E === exprtype(b) || return false

T === S || return false

return _isequal(a, b, E)::Bool
end
function _isequal(a, b, E)
if E === SYM
nameof(a) === nameof(b)
elseif E === ADD || E === MUL
Expand All @@ -208,9 +219,7 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
elseif E === TERM
a1 = arguments(a)
a2 = arguments(b)
isequal(operation(a), operation(b)) &&
length(a1) == length(a2) &&
all(isequal(l, r) for (l, r) in zip(a1, a2))
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
else
error_on_type()
end
Expand All @@ -223,13 +232,12 @@ Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymboli

## This is much faster than hash of an array of Any
hashvec(xs, z) = foldr(hash, xs, init=z)

const SYM_SALT = 0x4de7d7c66d41da43 % UInt
const ADD_SALT = 0xaddaddaddaddadda % UInt
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
const DIV_SALT = 0x334b218e73bbba53 % UInt
const POW_SALT = 0x2b55b97a6efb080c % UInt
function Base.hash(s::BasicSymbolic, salt::UInt)
function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
E = exprtype(s)
if E === SYM
hash(nameof(s), salt SYM_SALT)
Expand Down

0 comments on commit 3dc99d4

Please sign in to comment.