Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement hash consing for Sym #658

Merged
merged 20 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
d50294b
Add WeakValueDicts
bowenszhu Oct 14, 2024
7a3d8ba
Import `WeakValueDict`
bowenszhu Oct 16, 2024
e9e702b
Construct SymbolicUtils internal `WeakValueDict`
bowenszhu Oct 16, 2024
62403b9
Change `BasicSymbolic` types to `mutable` due to Julia finalizer
bowenszhu Oct 16, 2024
7e93ca4
Define hash extension function for incorporating symtype
bowenszhu Oct 16, 2024
9be3d03
Hash consing for `Sym`
bowenszhu Oct 16, 2024
93c2837
Merge remote-tracking branch 'origin/master' into hash-consing
bowenszhu Oct 17, 2024
f1a9a93
Test hash consing for `Sym` with different symtypes
bowenszhu Oct 18, 2024
70de918
Feat: Incorporate `metadata` into `BasicSymbolic` hash computation
bowenszhu Oct 25, 2024
2dac2a3
Apply hash consing also when there is metadata
bowenszhu Oct 25, 2024
c2d85c3
Create flyweight factory for `BasicSymbolic`
bowenszhu Oct 25, 2024
8957290
Add `isequal2` function for checking metadata comparison
bowenszhu Nov 5, 2024
2779856
Handle hash collision with customized `isequal2`
bowenszhu Nov 5, 2024
d36198f
Call outer constructor for `Sym` in `ConstructionBase.setproperties`
bowenszhu Nov 5, 2024
cf937b0
Test hash consing for `Sym` with metadata
bowenszhu Nov 5, 2024
765293a
Add docstring for the flyweight factory function
bowenszhu Nov 5, 2024
84a0596
Add docstring for `hash2`
bowenszhu Nov 5, 2024
187ce45
Add comment explaining why calling outer constructor in `setproperties`
bowenszhu Nov 5, 2024
55ca2ec
Refactor: Make `wvd` a constant global
bowenszhu Nov 6, 2024
13b642b
Rename the `isequal2` function to `isequal_with_metadata` for clarity.
bowenszhu Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Expand Down Expand Up @@ -57,6 +58,7 @@ SymbolicIndexingInterface = "0.3"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
WeakValueDicts = "0.1.0"
julia = "1.3"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments
# For ReverseDiffExt
import ArrayInterface
using WeakValueDicts: WeakValueDict

Base.@deprecate istree iscall
export istree, operation, arguments, sorted_arguments, iscall
Expand Down
30 changes: 21 additions & 9 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,38 @@ const EMPTY_DICT = sdict()
const EMPTY_DICT_T = typeof(EMPTY_DICT)

@compactify show_methods=false begin
@abstract struct BasicSymbolic{T} <: Symbolic{T}
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
metadata::Metadata = NO_METADATA
end
struct Sym{T} <: BasicSymbolic{T}
mutable struct Sym{T} <: BasicSymbolic{T}
name::Symbol = :OOF
end
struct Term{T} <: BasicSymbolic{T}
mutable struct Term{T} <: BasicSymbolic{T}
f::Any = identity # base/num if Pow; issorted if Add/Dict
arguments::Vector{Any} = EMPTY_ARGS
hash::RefValue{UInt} = EMPTY_HASH
end
struct Mul{T} <: BasicSymbolic{T}
mutable struct Mul{T} <: BasicSymbolic{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
struct Add{T} <: BasicSymbolic{T}
mutable struct Add{T} <: BasicSymbolic{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
struct Div{T} <: BasicSymbolic{T}
mutable struct Div{T} <: BasicSymbolic{T}
num::Any = 1
den::Any = 1
simplified::Bool = false
arguments::Vector{Any} = EMPTY_ARGS
end
struct Pow{T} <: BasicSymbolic{T}
mutable struct Pow{T} <: BasicSymbolic{T}
base::Any = 1
exp::Any = 1
arguments::Vector{Any} = EMPTY_ARGS
Expand All @@ -77,6 +77,8 @@ function exprtype(x::BasicSymbolic)
end
end

wvd = WeakValueDict{UInt, BasicSymbolic}()
bowenszhu marked this conversation as resolved.
Show resolved Hide resolved

# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
Expand Down Expand Up @@ -307,12 +309,22 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
end
end

hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
hash(T, hash(s, salt))
end

###
### Constructors
###

function Sym{T}(name::Symbol; kw...) where T
Sym{T}(; name=name, kw...)
function Sym{T}(name::Symbol; metadata = NO_METADATA, kw...) where {T}
if metadata==NO_METADATA
s = Sym{T}(; name, kw...)
get!(wvd, hash2(s), s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a dictionary that takes both hash and equal to handle collision.

else
Sym{T}(; name, metadata, kw...)
end
end

function Term{T}(f, args; kw...) where T
Expand Down
17 changes: 17 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using SymbolicUtils, Test

@testset "Sym" begin
x1 = only(@syms x)
x2 = only(@syms x)
@test x1 === x2
x3 = only(@syms x::Float64)
@test x1 !== x3
x4 = only(@syms x::Float64)
@test x1 !== x4
@test x3 === x4
x5 = only(@syms x::Int)
x6 = only(@syms x::Int)
@test x1 !== x5
@test x3 !== x5
@test x5 === x6
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ using Pkg, Test, SafeTestsets
# Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed
@safetestset "Fuzz" begin include("fuzz.jl") end
@safetestset "Adjoints" begin include("adjoints.jl") end
@safetestset "Hash Consing" begin include("hash_consing.jl") end
end
end
Loading