From 0e8ecbc6b380c49e9b2a575aabaa8f4c2820e7ba Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Jan 2025 20:32:42 +0530 Subject: [PATCH 1/2] fix: fix hashconsing not comparing metadata of symbolics inside metadata --- src/types.jl | 57 +++++++++++++++++++++++++++++++++++++++++++- test/hash_consing.jl | 13 +++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index 2572bd8b..a71fefa8 100644 --- a/src/types.jl +++ b/src/types.jl @@ -293,7 +293,62 @@ downstream packages like `ModelingToolkit.jl`, hence the need for this separate function. """ function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool - isequal(a, b) && isequal(metadata(a), metadata(b)) + isequal(a, b) && isequal_with_metadata(metadata(a), metadata(b)) +end + +""" + $(TYPEDSIGNATURES) + +Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively calling +`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal +metadata. +""" +function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{AbstractDict, NamedTuple}) + typeof(a) == typeof(b) || return false + length(a) == length(b) || return false + + for (k, v) in pairs(a) + haskey(b, k) || return false + isequal_with_metadata(v, b[k]) || return false + end + + for (k, v) in pairs(b) + haskey(a, k) || return false + isequal_with_metadata(v, a[k]) || return false + end + + return true +end + +""" + $(TYPEDSIGNATURES) + +Fallback method which uses `isequal`. +""" +isequal_with_metadata(a, b) = isequal(a, b) + +""" + $(TYPEDSIGNATURES) + +Specialized methods to check if two ranges are equal without comparing each element. +""" +isequal_with_metadata(a::AbstractRange, b::AbstractRange) = isequal(a, b) + +""" + $(TYPEDSIGNATURES) + +Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each element. +This is to ensure true equality of any symbolic elements, if present. +""" +function isequal_with_metadata(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple}) + typeof(a) == typeof(b) || return false + if a isa AbstractArray + size(a) == size(b) || return false + end # otherwise they're tuples and type equality also checks length equality + for (x, y) in zip(a, b) + isequal_with_metadata(x, y) || return false + end + return true end Base.one( s::Symbolic) = one( symtype(s)) diff --git a/test/hash_consing.jl b/test/hash_consing.jl index aa6aa7a9..7e23f94c 100644 --- a/test/hash_consing.jl +++ b/test/hash_consing.jl @@ -1,5 +1,5 @@ using SymbolicUtils, Test -using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2 +using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2, metadata struct Ctx1 end struct Ctx2 end @@ -108,3 +108,14 @@ end @test hash2(f, u0) != hash2(r, u0) @test f + a !== r + a end + +@testset "Symbolics in metadata" begin + @syms a b + a1 = setmetadata(a, Int, b) + b1 = setmetadata(b, Int, 3) + a2 = setmetadata(a, Int, b1) + @test a1 !== a2 + @test !SymbolicUtils.isequal_with_metadata(a1, a2) + @test metadata(metadata(a1)[Int]) === nothing + @test metadata(metadata(a2)[Int])[Int] == 3 +end From 90207558d8346d729320a3dbd6024dfaca55d7a1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 11 Jan 2025 11:15:44 +0530 Subject: [PATCH 2/2] fix: compare metadata of entire expression tree in hashconsing --- src/types.jl | 61 +++++++++++++++++++++++++++++++++++--------- test/hash_consing.jl | 9 +++++++ 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/src/types.jl b/src/types.jl index a71fefa8..4f186935 100644 --- a/src/types.jl +++ b/src/types.jl @@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false Base.isequal(::Symbolic, ::Missing) = false Base.isequal(::Missing, ::Symbolic) = false Base.isequal(::Symbolic, ::Symbolic) = false -coeff_isequal(a, b) = isequal(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b)) -function _allarequal(xs, ys)::Bool +coeff_isequal(a, b; comparator = isequal) = comparator(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b)) +function _allarequal(xs, ys; comparator = isequal)::Bool N = length(xs) length(ys) == N || return false for n = 1:N - isequal(xs[n], ys[n]) || return false + comparator(xs[n], ys[n]) || return false end return true end @@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S} T === S || return false return _isequal(a, b, E)::Bool end -function _isequal(a, b, E) +function _isequal(a, b, E; comparator = isequal) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict) + coeff_isequal(a.coeff, b.coeff; comparator) && comparator(a.dict, b.dict) elseif E === DIV - isequal(a.num, b.num) && isequal(a.den, b.den) + comparator(a.num, b.num) && comparator(a.den, b.den) elseif E === POW - isequal(a.exp, b.exp) && isequal(a.base, b.base) + comparator(a.exp, b.exp) && comparator(a.base, b.base) elseif E === TERM a1 = arguments(a) a2 = arguments(b) - isequal(operation(a), operation(b)) && _allarequal(a1, a2) + comparator(operation(a), operation(b)) && _allarequal(a1, a2; comparator) else error_on_type() end @@ -292,8 +292,14 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an downstream packages like `ModelingToolkit.jl`, hence the need for this separate function. """ -function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool - isequal(a, b) && isequal_with_metadata(metadata(a), metadata(b)) +function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S} + a === b && return true + + E = exprtype(a) + E === exprtype(b) || return false + + T === S || return false + _isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false end """ @@ -303,9 +309,9 @@ Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively `isequal_with_metadata` to ensure symbolic variables in the metadata also have equal metadata. """ -function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{AbstractDict, NamedTuple}) +function isequal_with_metadata(a::NamedTuple, b::NamedTuple) + a === b && return true typeof(a) == typeof(b) || return false - length(a) == length(b) || return false for (k, v) in pairs(a) haskey(b, k) || return false @@ -320,6 +326,36 @@ function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{Abst return true end +function isequal_with_metadata(a::AbstractDict, b::AbstractDict) + a === b && return true + typeof(a) == typeof(b) || return false + length(a) == length(b) || return false + + akeys = collect(keys(a)) + avisited = falses(length(akeys)) + bkeys = collect(keys(b)) + bvisited = falses(length(bkeys)) + + for k in akeys + idx = findfirst(eachindex(bkeys)) do i + !bvisited[i] && isequal_with_metadata(k, bkeys[i]) + end + idx === nothing && return false + bvisited[idx] = true + isequal_with_metadata(a[k], b[bkeys[idx]]) || return false + end + for (j, k) in enumerate(bkeys) + bvisited[j] && continue + idx = findfirst(eachindex(akeys)) do i + !avisited[i] && isequal_with_metadata(k, akeys[i]) + end + idx === nothing && return false + avisited[idx] = true + isequal_with_metadata(b[k], a[akeys[idx]]) || return false + end + return true +end + """ $(TYPEDSIGNATURES) @@ -341,6 +377,7 @@ Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each This is to ensure true equality of any symbolic elements, if present. """ function isequal_with_metadata(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple}) + a === b && return true typeof(a) == typeof(b) || return false if a isa AbstractArray size(a) == size(b) || return false diff --git a/test/hash_consing.jl b/test/hash_consing.jl index 7e23f94c..0bea6350 100644 --- a/test/hash_consing.jl +++ b/test/hash_consing.jl @@ -119,3 +119,12 @@ end @test metadata(metadata(a1)[Int]) === nothing @test metadata(metadata(a2)[Int])[Int] == 3 end + +@testset "Compare metadata of expression tree" begin + @syms a b + aa = setmetadata(a, Int, b) + @test aa !== a + @test isequal(a, aa) + @test !SymbolicUtils.isequal_with_metadata(a, aa) + @test !SymbolicUtils.isequal_with_metadata(2a, 2aa) +end