Skip to content

Commit

Permalink
fix: compare metadata of entire expression tree in hashconsing
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 11, 2025
1 parent 0e8ecbc commit 9020755
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 12 deletions.
61 changes: 49 additions & 12 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false
Base.isequal(::Symbolic, ::Missing) = false
Base.isequal(::Missing, ::Symbolic) = false
Base.isequal(::Symbolic, ::Symbolic) = false
coeff_isequal(a, b) = isequal(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
function _allarequal(xs, ys)::Bool
coeff_isequal(a, b; comparator = isequal) = comparator(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
function _allarequal(xs, ys; comparator = isequal)::Bool
N = length(xs)
length(ys) == N || return false
for n = 1:N
isequal(xs[n], ys[n]) || return false
comparator(xs[n], ys[n]) || return false
end
return true
end
Expand All @@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
T === S || return false
return _isequal(a, b, E)::Bool
end
function _isequal(a, b, E)
function _isequal(a, b, E; comparator = isequal)
if E === SYM
nameof(a) === nameof(b)
elseif E === ADD || E === MUL
coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
coeff_isequal(a.coeff, b.coeff; comparator) && comparator(a.dict, b.dict)
elseif E === DIV
isequal(a.num, b.num) && isequal(a.den, b.den)
comparator(a.num, b.num) && comparator(a.den, b.den)
elseif E === POW
isequal(a.exp, b.exp) && isequal(a.base, b.base)
comparator(a.exp, b.exp) && comparator(a.base, b.base)
elseif E === TERM
a1 = arguments(a)
a2 = arguments(b)
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
comparator(operation(a), operation(b)) && _allarequal(a1, a2; comparator)
else
error_on_type()
end
Expand All @@ -292,8 +292,14 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
function.
"""
function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool
isequal(a, b) && isequal_with_metadata(metadata(a), metadata(b))
function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S}
a === b && return true

E = exprtype(a)
E === exprtype(b) || return false

T === S || return false
_isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false
end

"""
Expand All @@ -303,9 +309,9 @@ Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively
`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal
metadata.
"""
function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{AbstractDict, NamedTuple})
function isequal_with_metadata(a::NamedTuple, b::NamedTuple)
a === b && return true
typeof(a) == typeof(b) || return false
length(a) == length(b) || return false

for (k, v) in pairs(a)
haskey(b, k) || return false
Expand All @@ -320,6 +326,36 @@ function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{Abst
return true
end

function isequal_with_metadata(a::AbstractDict, b::AbstractDict)
a === b && return true
typeof(a) == typeof(b) || return false
length(a) == length(b) || return false

akeys = collect(keys(a))
avisited = falses(length(akeys))
bkeys = collect(keys(b))
bvisited = falses(length(bkeys))

for k in akeys
idx = findfirst(eachindex(bkeys)) do i
!bvisited[i] && isequal_with_metadata(k, bkeys[i])
end
idx === nothing && return false
bvisited[idx] = true
isequal_with_metadata(a[k], b[bkeys[idx]]) || return false
end
for (j, k) in enumerate(bkeys)
bvisited[j] && continue
idx = findfirst(eachindex(akeys)) do i
!avisited[i] && isequal_with_metadata(k, akeys[i])
end
idx === nothing && return false
avisited[idx] = true
isequal_with_metadata(b[k], a[akeys[idx]]) || return false
end
return true
end

"""
$(TYPEDSIGNATURES)
Expand All @@ -341,6 +377,7 @@ Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each
This is to ensure true equality of any symbolic elements, if present.
"""
function isequal_with_metadata(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple})
a === b && return true
typeof(a) == typeof(b) || return false
if a isa AbstractArray
size(a) == size(b) || return false
Expand Down
9 changes: 9 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,12 @@ end
@test metadata(metadata(a1)[Int]) === nothing
@test metadata(metadata(a2)[Int])[Int] == 3
end

@testset "Compare metadata of expression tree" begin
@syms a b
aa = setmetadata(a, Int, b)
@test aa !== a
@test isequal(a, aa)
@test !SymbolicUtils.isequal_with_metadata(a, aa)
@test !SymbolicUtils.isequal_with_metadata(2a, 2aa)
end

0 comments on commit 9020755

Please sign in to comment.