diff --git a/src/types.jl b/src/types.jl
index 898259f4..2572bd8b 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -97,6 +97,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
     # Call outer constructor because hash consing cannot be applied in inner constructor
     @compactified obj::BasicSymbolic begin
         Sym => Sym{T}(nt_new.name; nt_new...)
+        Term => Term{T}(nt_new.f, nt_new.arguments; nt_new...)
+        Add => Add(T, nt_new.coeff, nt_new.dict; nt_new...)
+        Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new...)
+        Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new...)
+        Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new...)
         _ => Unityper.rt_constructor(obj){T}(;nt_new...)
     end
 end
@@ -298,6 +303,7 @@ Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymboli
 
 ## This is much faster than hash of an array of Any
 hashvec(xs, z) = foldr(hash, xs, init=z)
+hashvec2(xs, z) = foldr(hash2, xs, init=z)
 const SYM_SALT = 0x4de7d7c66d41da43 % UInt
 const ADD_SALT = 0xaddaddaddaddadda % UInt
 const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
@@ -344,10 +350,46 @@ objects. Unlike `Base.hash`, which only considers the expression structure, `has
 includes the metadata and symtype in the hash calculation. This can be beneficial for hash
 consing, allowing for more effective deduplication of symbolically equivalent expressions
 with different metadata or symtypes.
+
+Equivalent numbers of different types, such as `0.5::Float64` and
+`(1 // 2)::Rational{Int64}`, have the same default `Base.hash` value. The `hash2` function
+distinguishes these by including their numeric types in the hash calculation to ensure that
+symbolically equivalent expressions with different numeric types are treated as distinct
+objects.
 """
+hash2(s, salt::UInt) = hash(s, salt)
+function hash2(n::T, salt::UInt) where {T <: Number}
+    hash(T, hash(n, salt))
+end
 hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
 function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
-    hash(metadata(s), hash(T, hash(s, salt)))
+    E = exprtype(s)
+    h::UInt = 0
+    if E === SYM
+        h = hash(nameof(s), salt ⊻ SYM_SALT)
+    elseif E === ADD || E === MUL
+        hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
+        # XOR-fold the entries so the result does not depend on Dict iteration order.
+        hv = Base.hasha_seed
+        for (k, v) in s.dict
+            # Values go through `hash2` too, so `a => 0.5` and `a => 1 // 2` hash apart.
+            hv ⊻= hash2(k, hash2(v, zero(UInt)))
+        end
+        h = hash(hv, salt)
+        h = hash(hashoffset, hash2(s.coeff, h))
+    elseif E === DIV
+        h = hash2(s.num, hash2(s.den, salt ⊻ DIV_SALT))
+    elseif E === POW
+        h = hash2(s.exp, hash2(s.base, salt ⊻ POW_SALT))
+    elseif E === TERM
+        op = operation(s)
+        oph = op isa Function ? nameof(op) : op
+        # `hash2` keeps a symbolic operator's symtype and metadata in the hash.
+        h = hashvec2(arguments(s), hash2(oph, salt))
+    else
+        error_on_type()
+    end
+    hash(metadata(s), hash(T, h))
 end
 
 ###
@@ -395,7 +437,8 @@ function Term{T}(f, args; kw...) where T
         args = convert(Vector{Any}, args)
     end
 
-    Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
+    s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
+    BasicSymbolic(s)
 end
 
 function Term(f, args; metadata=NO_METADATA)
@@ -415,7 +458,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
         end
     end
 
-    Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
+    s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
+    BasicSymbolic(s)
 end
 
 function Mul(T, a, b; metadata=NO_METADATA, kw...)
@@ -430,7 +474,8 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
     else
         coeff = a
         dict = b
-        Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
+        s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
+        BasicSymbolic(s)
     end
 end
 
@@ -461,7 +506,7 @@ function maybe_intcoeff(x)
     end
 end
 
-function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
+function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
     if T<:Number && !(T<:SafeReal)
         n, d = quick_cancel(n, d)
     end
@@ -495,7 +540,8 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
         end
     end
 
-    Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
+    s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
+    BasicSymbolic(s)
 end
 
 function Div(n,d, simplified=false; kw...)
@@ -509,14 +555,16 @@ end
 
 @inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]
 
-function Pow{T}(a, b; metadata=NO_METADATA) where {T}
+# `kwargs...` only absorbs the field keywords splatted by `setproperties`; it is unused.
+function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
     _iszero(b) && return 1
     _isone(b) && return a
-    Pow{T}(; base=a, exp=b, arguments=[], metadata)
+    s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
+    BasicSymbolic(s)
 end
 
-function Pow(a, b; metadata=NO_METADATA)
-    Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata)
+function Pow(a, b; metadata = NO_METADATA, kwargs...)
+    Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)...; metadata, kwargs...)
 end
 
 function toterm(t::BasicSymbolic{T}) where T
diff --git a/test/hash_consing.jl b/test/hash_consing.jl
index aaf97997..aa6aa7a9 100644
--- a/test/hash_consing.jl
+++ b/test/hash_consing.jl
@@ -1,4 +1,5 @@
 using SymbolicUtils, Test
+using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2
 
 struct Ctx1 end
 struct Ctx2 end
@@ -24,3 +25,86 @@ struct Ctx2 end
     xm3 = setmetadata(x1, Ctx2, "meta_2")
     @test xm1 !== xm3
 end
+
+@syms a b c
+
+@testset "Term" begin
+    t1 = sin(a)
+    t2 = sin(a)
+    @test t1 === t2
+    t3 = Term(identity,[a])
+    t4 = Term(identity,[a])
+    @test t3 === t4
+    t5 = Term{Int}(identity,[a])
+    @test t3 !== t5
+    tm1 = setmetadata(t1, Ctx1, "meta_1")
+    @test t1 !== tm1
+end
+
+@testset "Add" begin
+    d1 = a + b
+    d2 = b + a
+    @test d1 === d2
+    d3 = b - 2 + a
+    d4 = a + b - 2
+    @test d3 === d4
+    d5 = Add(Int, 0, Dict(a => 1, b => 1))
+    @test d5 !== d1
+
+    dm1 = setmetadata(d1,Ctx1,"meta_1")
+    @test d1 !== dm1
+end
+
+@testset "Mul" begin
+    m1 = a*b
+    m2 = b*a
+    @test m1 === m2
+    m3 = 6*a*b
+    m4 = 3*a*2*b
+    @test m3 === m4
+    m5 = Mul(Int, 1, Dict(a => 1, b => 1))
+    @test m5 !== m1
+
+    mm1 = setmetadata(m1, Ctx1, "meta_1")
+    @test m1 !== mm1
+end
+
+@testset "Div" begin
+    v1 = a/b
+    v2 = a/b
+    @test v1 === v2
+    v3 = -1/a
+    v4 = -1/a
+    @test v3 === v4
+    v5 = 3a/6
+    v6 = 2a/4
+    @test v5 === v6
+    v7 = Div{Float64}(-1,a)
+    @test v7 !== v3
+
+    vm1 = setmetadata(v1,Ctx1, "meta_1")
+    @test vm1 !== v1
+end
+
+@testset "Pow" begin
+    p1 = a^b
+    p2 = a^b
+    @test p1 === p2
+    p3 = a^(2^-b)
+    p4 = a^(2^-b)
+    @test p3 === p4
+    p5 = Pow{Float64}(a,b)
+    @test p1 !== p5
+
+    pm1 = setmetadata(p1,Ctx1, "meta_1")
+    @test pm1 !== p1
+end
+
+@testset "Equivalent numbers" begin
+    f = 0.5
+    r = 1 // 2
+    @test hash(f) == hash(r)
+    u0 = zero(UInt)
+    @test hash2(f, u0) != hash2(r, u0)
+    @test f + a !== r + a
+end