diff --git a/Project.toml b/Project.toml index b0ed9c5..c9e0ab7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Legolas" uuid = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" authors = ["Beacon Biosignals, Inc."] -version = "0.5.14" +version = "0.5.15" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" diff --git a/src/schemas.jl b/src/schemas.jl index fb86d95..6185bce 100644 --- a/src/schemas.jl +++ b/src/schemas.jl @@ -447,7 +447,7 @@ end # We maintain an alias to the deprecated name for this type, xref https://github.com/beacon-biosignals/Legolas.jl/pull/100 Base.@deprecate_binding RequiredFieldInfo DeclaredFieldInfo -Base.:(==)(a::DeclaredFieldInfo, b::DeclaredFieldInfo) = all(getfield(a, i) == getfield(b, i) for i in 1:fieldcount(DeclaredFieldInfo)) +Base.:(==)(a::DeclaredFieldInfo, b::DeclaredFieldInfo) = _compare_fields(==, a, b) function _parse_declared_field_info!(f) f isa Symbol && (f = Expr(:(::), f, :Any)) @@ -660,17 +660,6 @@ function _generate_record_type_definitions(schema_version::SchemaVersion, record end end - # generate `base_overload_definitions` - equal_rhs_statement = foldr((x, y) -> :($x && $y), (:(a.$f == b.$f) for f in keys(record_fields))) - isequal_rhs_statement = foldr((x, y) -> :($x && $y), (:(isequal(a.$f, b.$f)) for f in keys(record_fields))) - hash_rhs_statement = foldr((x, y) -> :(hash($x, $y)), (:(r.$f) for f in keys(record_fields)); init=:h) - base_overload_definitions = quote - Base.:(==)(a::$R, b::$R) = $equal_rhs_statement - Base.isequal(a::$R, b::$R) = $isequal_rhs_statement - Base.hash(r::$R, h::UInt) = hash($R, $hash_rhs_statement) - Base.NamedTuple(r::$R) = (; $((:(r.$f) for f in keys(record_fields))...)) - end - # generate `arrow_overload_definitions` record_type_arrow_name = string("JuliaLang.Legolas.Generated.", Legolas.name(schema_version), '.', Legolas.version(schema_version)) record_type_arrow_name = Base.Meta.quot(Symbol(record_type_arrow_name)) @@ -689,7 +678,6 @@ function _generate_record_type_definitions(schema_version::SchemaVersion, record $inner_constructor_definitions end $outer_constructor_definitions - $base_overload_definitions $arrow_overload_definitions $Legolas.record_type(::$(Base.Meta.quot(typeof(schema_version)))) = $R $Legolas.schema_version_from_record(::$R) = $schema_version @@ -863,3 +851,38 @@ macro version(record_type, declared_fields_block) nothing end end + +##### +##### Base overload definitions +##### + +# Field-wise comparison for any two objects with exactly the same type. Most record +# comparisons will hit this method. +function _compare_fields(eq, x::T, y::T) where {T} + return all(i -> eq(getfield(x, i), getfield(y, i)), 1:fieldcount(T)) +end + +# Field-wise comparison of two arbitrary records, with equality contingent on matching +# schemas. Record comparisons for parametric record types with mismatched type parameters +# will hit this method, as well as mismatched record types (which will not compare equal). +function _compare_fields(eq, x::AbstractRecord, y::AbstractRecord) + svx = schema_version_from_record(x) + svy = schema_version_from_record(y) + return svx === svy && all(i -> eq(getfield(x, i), getfield(y, i)), 1:nfields(x)) +end + +Base.:(==)(x::AbstractRecord, y::AbstractRecord) = _compare_fields(==, x, y) + +Base.isequal(x::AbstractRecord, y::AbstractRecord) = _compare_fields(isequal, x, y) + +function Base.hash(r::AbstractRecord, h::UInt) + for i in nfields(r):-1:1 + h = hash(getfield(r, i), h) + end + return hash(typeof(r), h) +end + +function Base.NamedTuple(r::AbstractRecord) + names = fieldnames(typeof(r)) + return NamedTuple{names}(map(x -> getfield(r, x), names)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 365b5ae..161e1bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -295,6 +295,13 @@ end xs::Union{Vector{String},Missing} end +@schema "test.union-missing" UnionMissing + +@version UnionMissingV1 begin + a::Union{Int,Missing} + b::Union{Int,Missing} +end + @testset "`Legolas.@version` and associated utilities for declared `Legolas.SchemaVersion`s" begin @testset "Legolas.SchemaVersionDeclarationError" begin @test_throws SchemaVersionDeclarationError("malformed or missing field declaration(s)") eval(:(@version(NewV1, $(Expr(:block, LineNumberNode(1, :test)))))) @@ -511,6 +518,27 @@ end @test_throws TypeError ParamV1(; i=1.0) @test_throws ArgumentError ParamV1{Int}(; i=1.1) end + + @testset "equality and hashing" begin + a = ChildV1(; x=[1, 2], y="hello", z=missing) + b = ChildV1(; x=[1, 2], y="hello", z=missing) + @test a !== b + @test isequal(a, b) + @test ismissing(a == b) + @test hash(a) == hash(b) + u1 = UnionMissingV1(; a=missing, b=1) + u2 = UnionMissingV1(; a=missing, b=2) + @test u1 != u2 + @test !isequal(u1, u2) + p32 = ParamV1(; i=one(Int32)) + p64 = ParamV1(; i=one(Int64)) + @test p32 == p64 + @test isequal(p32, p64) + 🧑 = ParentV1(; x=[4, 20], y="") + 🧒 = ChildV1(; x=[4, 20], y="", z=missing) + @test 🧑 != 🧒 + @test !isequal(🧑, 🧒) + end end @testset "miscellaneous Legolas/src/tables.jl tests" begin