Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix hashconsing not comparing nested metadata and inside expressions #691

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 102 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false
Base.isequal(::Symbolic, ::Missing) = false
Base.isequal(::Missing, ::Symbolic) = false
Base.isequal(::Symbolic, ::Symbolic) = false
coeff_isequal(a, b) = isequal(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
function _allarequal(xs, ys)::Bool
coeff_isequal(a, b; comparator = isequal) = comparator(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
function _allarequal(xs, ys; comparator = isequal)::Bool
N = length(xs)
length(ys) == N || return false
for n = 1:N
isequal(xs[n], ys[n]) || return false
comparator(xs[n], ys[n]) || return false
end
return true
end
Expand All @@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
T === S || return false
return _isequal(a, b, E)::Bool
end
function _isequal(a, b, E)
function _isequal(a, b, E; comparator = isequal)
if E === SYM
nameof(a) === nameof(b)
elseif E === ADD || E === MUL
coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
coeff_isequal(a.coeff, b.coeff; comparator) && comparator(a.dict, b.dict)
elseif E === DIV
isequal(a.num, b.num) && isequal(a.den, b.den)
comparator(a.num, b.num) && comparator(a.den, b.den)
elseif E === POW
isequal(a.exp, b.exp) && isequal(a.base, b.base)
comparator(a.exp, b.exp) && comparator(a.base, b.base)
elseif E === TERM
a1 = arguments(a)
a2 = arguments(b)
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
comparator(operation(a), operation(b)) && _allarequal(a1, a2; comparator)
else
error_on_type()
end
Expand All @@ -292,8 +292,100 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
function.
"""
function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool
isequal(a, b) && isequal(metadata(a), metadata(b))
function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S}
a === b && return true

E = exprtype(a)
E === exprtype(b) || return false

T === S || return false
_isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false
end

"""
$(TYPEDSIGNATURES)

Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively calling
`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal
metadata.
"""
function isequal_with_metadata(a::NamedTuple, b::NamedTuple)
a === b && return true
typeof(a) == typeof(b) || return false

for (k, v) in pairs(a)
haskey(b, k) || return false
isequal_with_metadata(v, b[k]) || return false
end

for (k, v) in pairs(b)
haskey(a, k) || return false
isequal_with_metadata(v, a[k]) || return false
end

return true
end

function isequal_with_metadata(a::AbstractDict, b::AbstractDict)
a === b && return true
typeof(a) == typeof(b) || return false
length(a) == length(b) || return false

akeys = collect(keys(a))
avisited = falses(length(akeys))
bkeys = collect(keys(b))
bvisited = falses(length(bkeys))

for k in akeys
idx = findfirst(eachindex(bkeys)) do i
!bvisited[i] && isequal_with_metadata(k, bkeys[i])
end
idx === nothing && return false
bvisited[idx] = true
isequal_with_metadata(a[k], b[bkeys[idx]]) || return false
end
for (j, k) in enumerate(bkeys)
bvisited[j] && continue
idx = findfirst(eachindex(akeys)) do i
!avisited[i] && isequal_with_metadata(k, akeys[i])
end
idx === nothing && return false
avisited[idx] = true
isequal_with_metadata(b[k], a[akeys[idx]]) || return false
end
return true
end

"""
$(TYPEDSIGNATURES)

Fallback method which uses `isequal`.
"""
isequal_with_metadata(a, b) = isequal(a, b)

"""
$(TYPEDSIGNATURES)

Specialized methods to check if two ranges are equal without comparing each element.
"""
isequal_with_metadata(a::AbstractRange, b::AbstractRange) = isequal(a, b)

"""
$(TYPEDSIGNATURES)

Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each element.
This is to ensure true equality of any symbolic elements, if present.
"""
function isequal_with_metadata(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple})
a === b && return true
typeof(a) == typeof(b) || return false
if a isa AbstractArray
size(a) == size(b) || return false
end # otherwise they're tuples and type equality also checks length equality
for (x, y) in zip(a, b)
isequal_with_metadata(x, y) || return false
end
return true
end

Base.one( s::Symbolic) = one( symtype(s))
Expand Down
22 changes: 21 additions & 1 deletion test/hash_consing.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using SymbolicUtils, Test
using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2
using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2, metadata

struct Ctx1 end
struct Ctx2 end
Expand Down Expand Up @@ -108,3 +108,23 @@ end
@test hash2(f, u0) != hash2(r, u0)
@test f + a !== r + a
end

@testset "Symbolics in metadata" begin
@syms a b
a1 = setmetadata(a, Int, b)
b1 = setmetadata(b, Int, 3)
a2 = setmetadata(a, Int, b1)
@test a1 !== a2
@test !SymbolicUtils.isequal_with_metadata(a1, a2)
@test metadata(metadata(a1)[Int]) === nothing
@test metadata(metadata(a2)[Int])[Int] == 3
end

@testset "Compare metadata of expression tree" begin
@syms a b
aa = setmetadata(a, Int, b)
@test aa !== a
@test isequal(a, aa)
@test !SymbolicUtils.isequal_with_metadata(a, aa)
@test !SymbolicUtils.isequal_with_metadata(2a, 2aa)
end
Loading