From aec72ce5ba3101e27b54acccbc214ae75c993ab1 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 29 Sep 2023 10:55:19 -0700 Subject: [PATCH 1/8] Use common methods for Base functions instead of generating Currently, methods for `==`, `isequal`, `hash`, and `NamedTuple` are generated by the `@version` macro for every row type that gets defined. The bodies of these methods always follow the same patterns, none of which require any information that's only available in the context where the definitions are generated. In fact, specific methods don't need to be generated at all; `AbstractRecord`s can all use the same generic methods for these functions. This provides the following additional benefits: - When many schema versions are defined, this significantly reduces the excessive number of redundant methods which make method autocompletion effectively useless. - Defining the generic methods in terms of regular functions rather than generating expressions in the macro means that it's easier to reason about the methods' behavior and ensure overall consistency. For example, moving from generating a chain of `==` and `&&` over a record's fields to directly using `all` means that `missing` is treated consistently (see issue 101). Note that the generic method for `hash` as defined here loops over the fields in reverse order. This is to match the previous `foldr` behavior, ensuring that hashes don't change with this implementation. --- Project.toml | 2 +- src/schemas.jl | 36 ++++++++++++++++++++++++------------ test/runtests.jl | 9 +++++++++ 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index b0ed9c5..c9e0ab7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Legolas" uuid = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" authors = ["Beacon Biosignals, Inc."] -version = "0.5.14" +version = "0.5.15" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" diff --git a/src/schemas.jl b/src/schemas.jl index fb86d95..0c3ebb6 100644 --- a/src/schemas.jl +++ b/src/schemas.jl @@ -660,17 +660,6 @@ function _generate_record_type_definitions(schema_version::SchemaVersion, record end end - # generate `base_overload_definitions` - equal_rhs_statement = foldr((x, y) -> :($x && $y), (:(a.$f == b.$f) for f in keys(record_fields))) - isequal_rhs_statement = foldr((x, y) -> :($x && $y), (:(isequal(a.$f, b.$f)) for f in keys(record_fields))) - hash_rhs_statement = foldr((x, y) -> :(hash($x, $y)), (:(r.$f) for f in keys(record_fields)); init=:h) - base_overload_definitions = quote - Base.:(==)(a::$R, b::$R) = $equal_rhs_statement - Base.isequal(a::$R, b::$R) = $isequal_rhs_statement - Base.hash(r::$R, h::UInt) = hash($R, $hash_rhs_statement) - Base.NamedTuple(r::$R) = (; $((:(r.$f) for f in keys(record_fields))...)) - end - # generate `arrow_overload_definitions` record_type_arrow_name = string("JuliaLang.Legolas.Generated.", Legolas.name(schema_version), '.', Legolas.version(schema_version)) record_type_arrow_name = Base.Meta.quot(Symbol(record_type_arrow_name)) @@ -689,7 +678,6 @@ function _generate_record_type_definitions(schema_version::SchemaVersion, record $inner_constructor_definitions end $outer_constructor_definitions - $base_overload_definitions $arrow_overload_definitions $Legolas.record_type(::$(Base.Meta.quot(typeof(schema_version)))) = $R $Legolas.schema_version_from_record(::$R) = $schema_version @@ -863,3 +851,27 @@ macro version(record_type, declared_fields_block) nothing end end + +##### +##### Base overload definitions +##### + +function Base.:(==)(x::T, y::T) where {T<:AbstractRecord} + return all(i -> getfield(x, i) == getfield(y, i), 1:fieldcount(T)) +end + +function Base.isequal(x::T, y::T) where {T<:AbstractRecord} + return all(i -> isequal(getfield(x, i), getfield(y, i)), 1:fieldcount(T)) +end + +function Base.hash(r::AbstractRecord, h::UInt) + for i in nfields(r):-1:1 + h = hash(getfield(r, i), h) + end + return hash(typeof(r), h) +end + +function Base.NamedTuple(r::AbstractRecord) + names = fieldnames(typeof(r)) + return NamedTuple{names}(map(x -> getfield(r, x), names)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 365b5ae..d985c2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -511,6 +511,15 @@ end @test_throws TypeError ParamV1(; i=1.0) @test_throws ArgumentError ParamV1{Int}(; i=1.1) end + + @testset "equality and hashing" begin + c = ChildV1(; x=[1, 2], y="hello", z=missing) + @test isequal(c, c) + @test ismissing(c == c) + if UInt === UInt64 # value will be different depending on system word size + @test hash(c) === 0x07055951b3aa478e + end + end end @testset "miscellaneous Legolas/src/tables.jl tests" begin From 639ec146616d6fbd7ce7428d6cd40544451784b8 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 29 Sep 2023 11:29:35 -0700 Subject: [PATCH 2/8] Don't test hash value directly --- test/runtests.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d985c2e..7917892 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -516,9 +516,7 @@ end c = ChildV1(; x=[1, 2], y="hello", z=missing) @test isequal(c, c) @test ismissing(c == c) - if UInt === UInt64 # value will be different depending on system word size - @test hash(c) === 0x07055951b3aa478e - end + @test hash(c) isa UInt # NOTE: can't rely on particular values end end From ee86e096868544e2e3ad7f2ad08594dd52dbf711 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 29 Sep 2023 13:25:50 -0700 Subject: [PATCH 3/8] Add (nearly) exact test case from issue 101 --- test/runtests.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 7917892..142fd73 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -295,6 +295,13 @@ end xs::Union{Vector{String},Missing} end +@schema "test.union-missing" UnionMissing + +@version UnionMissingV1 begin + a::Union{Int,Missing} + b::Union{Int,Missing} +end + @testset "`Legolas.@version` and associated utilities for declared `Legolas.SchemaVersion`s" begin @testset "Legolas.SchemaVersionDeclarationError" begin @test_throws SchemaVersionDeclarationError("malformed or missing field declaration(s)") eval(:(@version(NewV1, $(Expr(:block, LineNumberNode(1, :test)))))) @@ -517,6 +524,8 @@ end @test isequal(c, c) @test ismissing(c == c) @test hash(c) isa UInt # NOTE: can't rely on particular values + @test UnionMissingV1(; a=missing, b=1) != UnionMissingV1(; a=missing, b=2) + @test !isequal(UnionMissingV1(; a=missing, b=1), UnionMissingV1(; a=missing, b=2)) end end From 5d129936ac11ca74d6b0811ae46d1d7283ff5515 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Fri, 29 Sep 2023 18:28:50 -0700 Subject: [PATCH 4/8] This is honestly pretty cursed but hear me out We want to compare by field, ignoring any type parameters, but not allowing comparison of differing row types. That kind of type query can't be specified via dispatch, but we can add an additional check on the equality of the row types. This is done in a, uh, creative way. ```julia julia> @schema "test.param" Param julia> @version ParamV1 begin x::(<:Integer) end julia> x = ParamV1(; i=one(Int32)); julia> typeof(x) ParamV1{Int32} julia> Base.typename(ans) typename(ParamV1) julia> ans.wrapper ParamV1 julia> Base.unwrap_unionall(ans) ParamV1{_I<:Integer} ``` By applying this transformation to the types of two records, we can compare without considering the type parameter. And since we know that the types of the records are the same, we can use `nfields(x)` with the peace of mind that `nfields(y)` will be the same. --- src/schemas.jl | 15 +++++++++++---- test/runtests.jl | 2 ++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/schemas.jl b/src/schemas.jl index 0c3ebb6..cb3994a 100644 --- a/src/schemas.jl +++ b/src/schemas.jl @@ -856,12 +856,19 @@ end ##### Base overload definitions ##### -function Base.:(==)(x::T, y::T) where {T<:AbstractRecord} - return all(i -> getfield(x, i) == getfield(y, i), 1:fieldcount(T)) +_typeof(r::AbstractRecord) = Base.unwrap_unionall(Base.typename(typeof(r)).wrapper) + +_type_equal(x::R, y::R) where {R<:AbstractRecord} = true +_type_equal(x::AbstractRecord, y::AbstractRecord) = _typeof(x) === _typeof(y) + +function Base.:(==)(x::AbstractRecord, y::AbstractRecord) + _type_equal(x, y) || return false + return all(i -> getfield(x, i) == getfield(y, i), 1:nfields(x)) end -function Base.isequal(x::T, y::T) where {T<:AbstractRecord} - return all(i -> isequal(getfield(x, i), getfield(y, i)), 1:fieldcount(T)) +function Base.isequal(x::AbstractRecord, y::AbstractRecord) + _type_equal(x, y) || return false + return all(i -> isequal(getfield(x, i), getfield(y, i)), 1:nfields(x)) end function Base.hash(r::AbstractRecord, h::UInt) diff --git a/test/runtests.jl b/test/runtests.jl index 142fd73..877510a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -526,6 +526,8 @@ end @test hash(c) isa UInt # NOTE: can't rely on particular values @test UnionMissingV1(; a=missing, b=1) != UnionMissingV1(; a=missing, b=2) @test !isequal(UnionMissingV1(; a=missing, b=1), UnionMissingV1(; a=missing, b=2)) + @test ParamV1(; i=one(Int32)) == ParamV1(; i=one(Int64)) + @test isequal(ParamV1(; i=one(Int32)), ParamV1(; i=one(Int64))) end end From 88288f57ffa5d0474ed86ad7615e9bcac141216b Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Sat, 30 Sep 2023 13:54:12 -0700 Subject: [PATCH 5/8] Use Legolas functions instead of Base internals I don't want to go to hell just yet --- src/schemas.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schemas.jl b/src/schemas.jl index cb3994a..89a32e2 100644 --- a/src/schemas.jl +++ b/src/schemas.jl @@ -856,7 +856,7 @@ end ##### Base overload definitions ##### -_typeof(r::AbstractRecord) = Base.unwrap_unionall(Base.typename(typeof(r)).wrapper) +_typeof(r::AbstractRecord) = record_type(schema_version_from_record(r)) _type_equal(x::R, y::R) where {R<:AbstractRecord} = true _type_equal(x::AbstractRecord, y::AbstractRecord) = _typeof(x) === _typeof(y) From c7792ded2aa6843730fe3dfde46c4ff41d1f7a10 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Sat, 30 Sep 2023 14:34:07 -0700 Subject: [PATCH 6/8] Simplify further --- src/schemas.jl | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/schemas.jl b/src/schemas.jl index 89a32e2..6185bce 100644 --- a/src/schemas.jl +++ b/src/schemas.jl @@ -447,7 +447,7 @@ end # We maintain an alias to the deprecated name for this type, xref https://github.com/beacon-biosignals/Legolas.jl/pull/100 Base.@deprecate_binding RequiredFieldInfo DeclaredFieldInfo -Base.:(==)(a::DeclaredFieldInfo, b::DeclaredFieldInfo) = all(getfield(a, i) == getfield(b, i) for i in 1:fieldcount(DeclaredFieldInfo)) +Base.:(==)(a::DeclaredFieldInfo, b::DeclaredFieldInfo) = _compare_fields(==, a, b) function _parse_declared_field_info!(f) f isa Symbol && (f = Expr(:(::), f, :Any)) @@ -856,21 +856,25 @@ end ##### Base overload definitions ##### -_typeof(r::AbstractRecord) = record_type(schema_version_from_record(r)) - -_type_equal(x::R, y::R) where {R<:AbstractRecord} = true -_type_equal(x::AbstractRecord, y::AbstractRecord) = _typeof(x) === _typeof(y) - -function Base.:(==)(x::AbstractRecord, y::AbstractRecord) - _type_equal(x, y) || return false - return all(i -> getfield(x, i) == getfield(y, i), 1:nfields(x)) +# Field-wise comparison for any two objects with exactly the same type. Most record +# comparisons will hit this method. +function _compare_fields(eq, x::T, y::T) where {T} + return all(i -> eq(getfield(x, i), getfield(y, i)), 1:fieldcount(T)) end -function Base.isequal(x::AbstractRecord, y::AbstractRecord) - _type_equal(x, y) || return false - return all(i -> isequal(getfield(x, i), getfield(y, i)), 1:nfields(x)) +# Field-wise comparison of two arbitrary records, with equality contingent on matching +# schemas. Record comparisons for parametric record types with mismatched type parameters +# will hit this method, as well as mismatched record types (which will not compare equal). +function _compare_fields(eq, x::AbstractRecord, y::AbstractRecord) + svx = schema_version_from_record(x) + svy = schema_version_from_record(y) + return svx === svy && all(i -> eq(getfield(x, i), getfield(y, i)), 1:nfields(x)) end +Base.:(==)(x::AbstractRecord, y::AbstractRecord) = _compare_fields(==, x, y) + +Base.isequal(x::AbstractRecord, y::AbstractRecord) = _compare_fields(isequal, x, y) + function Base.hash(r::AbstractRecord, h::UInt) for i in nfields(r):-1:1 h = hash(getfield(r, i), h) From 2cf0556ac591be906f61fba7f163241d729255c5 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 2 Oct 2023 13:10:57 -0700 Subject: [PATCH 7/8] Make tests better Co-authored-by: Curtis Vogt --- test/runtests.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 877510a..caaa41b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -520,10 +520,12 @@ end end @testset "equality and hashing" begin - c = ChildV1(; x=[1, 2], y="hello", z=missing) - @test isequal(c, c) - @test ismissing(c == c) - @test hash(c) isa UInt # NOTE: can't rely on particular values + a = ChildV1(; x=[1, 2], y="hello", z=missing) + b = ChildV1(; x=[1, 2], y="hello", z=missing) + @test a !== b + @test isequal(a, b) + @test ismissing(a == b) + @test hash(a) == hash(b) @test UnionMissingV1(; a=missing, b=1) != UnionMissingV1(; a=missing, b=2) @test !isequal(UnionMissingV1(; a=missing, b=1), UnionMissingV1(; a=missing, b=2)) @test ParamV1(; i=one(Int32)) == ParamV1(; i=one(Int64)) From 1d054bd2162a61f1cac040d5c46c46b1c345bea7 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 2 Oct 2023 15:41:23 -0700 Subject: [PATCH 8/8] Better formatted tests + add parent/child equality tests --- test/runtests.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index caaa41b..161e1bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -526,10 +526,18 @@ end @test isequal(a, b) @test ismissing(a == b) @test hash(a) == hash(b) - @test UnionMissingV1(; a=missing, b=1) != UnionMissingV1(; a=missing, b=2) - @test !isequal(UnionMissingV1(; a=missing, b=1), UnionMissingV1(; a=missing, b=2)) - @test ParamV1(; i=one(Int32)) == ParamV1(; i=one(Int64)) - @test isequal(ParamV1(; i=one(Int32)), ParamV1(; i=one(Int64))) + u1 = UnionMissingV1(; a=missing, b=1) + u2 = UnionMissingV1(; a=missing, b=2) + @test u1 != u2 + @test !isequal(u1, u2) + p32 = ParamV1(; i=one(Int32)) + p64 = ParamV1(; i=one(Int64)) + @test p32 == p64 + @test isequal(p32, p64) + 🧑 = ParentV1(; x=[4, 20], y="") + 🧒 = ChildV1(; x=[4, 20], y="", z=missing) + @test 🧑 != 🧒 + @test !isequal(🧑, 🧒) end end