diff --git a/Project.toml b/Project.toml index c1fe007..761b4fa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FeatureTransforms" uuid = "8fd68953-04b8-4117-ac19-158bf6de9782" authors = ["Invenia Technical Computing Corporation"] -version = "0.3.11" +version = "0.3.12" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index b9b0a22..8e6212b 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -21,6 +21,18 @@ git-tree-sha1 = "f713d583d10fc036252fd826feebc6c173c522a8" uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" version = "0.9.5" +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "f9982ef575e19b0e5c7a98c6e75ee496c0f73a93" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.12.0" + +[[ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.2" + [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956" @@ -83,10 +95,10 @@ deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[FeatureTransforms]] -deps = ["Dates", "NamedDims", "Statistics", "Tables"] +deps = ["Dates", "InteractiveUtils", "NamedDims", "Statistics", "StatsBase", "Tables"] path = ".." 
uuid = "8fd68953-04b8-4117-ac19-158bf6de9782" -version = "0.3.3" +version = "0.3.12" [[Formatting]] deps = ["Printf"] @@ -108,12 +120,23 @@ version = "0.1.1" deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "a7254c0acd8e62f1ac75ad24d5db43f5f19f3c65" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.2" + [[InvertedIndices]] deps = ["Test"] git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" version = "1.0.0" +[[IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + [[IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" uuid = "82899510-4779-5014-852e-03e436cf321d" @@ -148,6 +171,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "e5718a00af0ab9756305a0392832c8952c7426c1" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.6" + [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -257,6 +286,17 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[[StatsAPI]] +git-tree-sha1 = "d88665adc9bcf45903013af0982e2fd05ae3d0a6" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.2.0" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.16" + [[StructTypes]] deps = ["Dates", "UUIDs"] git-tree-sha1 = 
"5d8e3d60f17791c4c64baf69a2bc5e7023ee73aa" diff --git a/docs/src/api.md b/docs/src/api.md index a34c30e..90a4c42 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -13,7 +13,7 @@ AbstractScaling HoD Power Periodic -MeanStdScaling +StandardScaling IdentityScaling InverseHyperbolicSine LinearCombination @@ -35,3 +35,8 @@ FeatureTransforms.is_transformable FeatureTransforms.transform! FeatureTransforms.transform ``` + +## Deprecated functionality +```@docs +MeanStdScaling +``` diff --git a/docs/src/examples.md b/docs/src/examples.md index 2120b03..d2c307c 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -7,6 +7,8 @@ First we load some hourly weather data: ```jldoctest example julia> using DataFrames, Dates, FeatureTransforms +julia> using FeatureTransforms: fit! + julia> df = DataFrame( :time => DateTime(2018, 9, 10):Hour(1):DateTime(2018, 9, 10, 23), :temperature => [10.6, 9.5, 8.9, 8.9, 8.4, 8.4, 7.7, 8.9, 11.7, 13.9, 16.2, 17.7, 18.9, 20.0, 21.2, 21.7, 21.7, 21.2, 20.0, 18.4, 16.7, 15.0, 13.9, 12.7], @@ -76,14 +78,18 @@ julia> test_df = feature_df[end-1:end, :]; julia> output_cols = [:temperature, :humidity]; ``` -For many models it is helpful to normalize the training data. -We can use `MeanStdScaling` for that purpose. +For many models it is helpful to standardise the training data. +We can use `StandardScaling` for that purpose. Note that we are mutating the data frame in-place using `apply!` one column at a time. 
```jldoctest example -julia> temp_scaling = MeanStdScaling(train_df; cols=:temperature); +julia> temp_scaling = StandardScaling(); + +julia> fit!(temp_scaling, train_df; cols=:temperature); + +julia> hum_scaling = StandardScaling(); -julia> hum_scaling = MeanStdScaling(train_df; cols=:humidity); +julia> fit!(hum_scaling, train_df; cols=:humidity); julia> FeatureTransforms.apply!(train_df, temp_scaling; cols=:temperature); @@ -111,7 +117,7 @@ julia> FeatureTransforms.apply!(train_df, hum_scaling; cols=:humidity) 7 rows omitted ``` -We can use the same `scaling` transform to normalize the test data: +We can use the same `scaling` transform to standardise the test data: ```jldoctest example julia> FeatureTransforms.apply!(test_df, temp_scaling; cols=:temperature); diff --git a/docs/src/transforms.md b/docs/src/transforms.md index 9f9cb73..be76732 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -2,12 +2,14 @@ A `Transform` defines a transformation of data for feature engineering purposes. Some examples are scaling, periodic functions, linear combination, and one-hot encoding. +Transforms can be stateless, for example the power transform, or they can be stateful and fit to the data, such as the [`StandardScaling`](@ref). ```@meta DocTestSetup = quote using DataFrames using Dates using FeatureTransforms + using FeatureTransforms: fit! end ``` @@ -20,6 +22,15 @@ For example, the following defines a squaring operation (i.e. raise to the power julia> p = Power(2); ``` +A stateful transform, such as a [`StandardScaling`](@ref) should also be fit to the data before it is applied: +```julia-repl +julia> s = StandardScaling(); + +julia> x = rand(5); + +julia> FeatureTransforms.fit!(s, x); +``` + ## Methods to apply a transform Given some data `x`, there are three main methods to apply a transform. @@ -90,8 +101,8 @@ A single `Transform` instance can be applied to different data types, with suppo !!! 
note Some `Transform` subtypes have restrictions on how they can be applied once constructed. - For instance, `MeanStdScaling` stores the mean and standard deviation of some data, potentially specified via some dimension and column names. - So `MeanStdScaling` should only be applied to the same data, and for the same dimension and subset of column names, as those used in construction. + For instance, `StandardScaling` stores the mean and standard deviation of some data, potentially specified via some dimension and column names. + So `StandardScaling` should only be applied to the same data, and for the same dimension and subset of column names, as those used when fitting it with `fit!`. ## Applying to `AbstractArray` @@ -144,15 +155,19 @@ julia> M 1.0 5.0 3.0 6.0 -julia> normalize_row = MeanStdScaling(M; dims=1, inds=[2]) -MeanStdScaling(3.0, 2.8284271247461903) +julia> normalize_row = StandardScaling(); + +julia> fit!(normalize_row, M; dims=1, inds=[2]) +StandardScaling(3.0, 2.8284271247461903) julia> normalize_row(M; dims=1, inds=[2]) 1×2 Matrix{Float64}: -0.707107 0.707107 -julia> normalize_col = MeanStdScaling(M; dims=2, inds=[2]) -MeanStdScaling(5.0, 1.0) +julia> normalize_col = StandardScaling(); + +julia> fit!(normalize_col, M; dims=2, inds=[2]) +StandardScaling(5.0, 1.0) julia> normalize_col(M; dims=2, inds=[2]) 3×1 Matrix{Float64}: @@ -172,7 +187,9 @@ If no `header` is given, the default from [`Tables.table`](https://tables.juliad ```jldoctest transforms julia> nt = (a = [2.0, 1.0, 3.0], b = [4.0, 5.0, 6.0]); -julia> scaling = MeanStdScaling(nt); # compute statistics using all data +julia> scaling = StandardScaling(); + +julia> fit!(scaling, nt); # compute statistics using all data julia> FeatureTransforms.apply(nt, scaling; header=[:a_norm, :b_norm]) (a_norm = [-0.8017837257372732, -1.3363062095621219, -0.2672612419124244], b_norm = [0.2672612419124244, 0.8017837257372732, 1.3363062095621219]) @@ -219,12 +236,14 @@ julia> feature_df = hcat(hod_df, lc_df) ## 
Transform-specific keyword arguments Some transforms have specific keyword arguments that can be passed to `apply`/`apply!`. -For example, `MeanStdScaling` can invert the original scaling using the `inverse` argument: +For example, `StandardScaling` can invert the original scaling using the `inverse` argument: ```jldoctest transforms julia> nt = (a = [2.0, 1.0, 3.0], b = [4.0, 5.0, 6.0]); -julia> scaling = MeanStdScaling(nt); +julia> scaling = StandardScaling(); + +julia> fit!(scaling, nt); julia> FeatureTransforms.apply!(nt, scaling); diff --git a/src/FeatureTransforms.jl b/src/FeatureTransforms.jl index f278731..03e3623 100644 --- a/src/FeatureTransforms.jl +++ b/src/FeatureTransforms.jl @@ -7,13 +7,14 @@ using Tables export Transform, transform, transform! export HoD, LinearCombination, OneHotEncoding, Periodic, Power -export AbstractScaling, IdentityScaling, MeanStdScaling +export AbstractScaling, IdentityScaling, MeanStdScaling, StandardScaling export LogTransform, InverseHyperbolicSine include("utils.jl") include("traits.jl") include("transform.jl") include("apply.jl") +include("fit.jl") # Transform implementations include("linear_combination.jl") @@ -26,7 +27,6 @@ include("temporal.jl") include("test_utils.jl") -# TODO: remove in v0.4 https://github.com/invenia/FeatureTransforms.jl/issues/82 -Base.@deprecate_binding is_transformable TestUtils.is_transformable +include("deprecated.jl") end diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 0000000..82685c3 --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,74 @@ +""" + MeanStdScaling(μ, σ) <: AbstractScaling + +Linearly scale the data by the statistical mean `μ` and standard deviation `σ`. +This is also known as standardization, or the Z score transform. + +# Keyword arguments to `apply` +* `inverse=true`: inverts the scaling (e.g. to reconstruct the unscaled data). +* `eps=1e-3`: used in place of all 0 values in `σ` before scaling (if `inverse=false`). 
+""" +struct MeanStdScaling <: AbstractScaling + μ::Real + σ::Real + + """ + MeanStdScaling(A::AbstractArray; dims=:, inds=:) -> MeanStdScaling + MeanStdScaling(table, [cols]) -> MeanStdScaling + + Construct a [`MeanStdScaling`](@ref) transform from the statistics of the given data. + By default _all the data_ is considered when computing the mean and standard deviation. + This can be restricted to certain slices via the keyword arguments (see below). + + Since `MeanStdScaling` is a stateful transform, i.e. the parameters depend on the data + it's given, you should define it independently before applying it so you can keep the + information for later use. For instance, if you want to invert the transform or apply it + to a test set. + + # `AbstractArray` keyword arguments + * `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims. + * `inds=:`: the indices to use in computing the statistics. Default uses all indices. + + # `Table` keyword arguments + * `cols`: the columns to use in computing the statistics. Default uses all columns. + + !!! note + If you want the `MeanStdScaling` to transform your data consistently you should use + the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply` + might rescale the wrong data or throw an error. + """ + function MeanStdScaling(A::AbstractArray; dims=:, inds=:) + _depwarn() + dims == Colon() && return new(compute_stats(A)...) + return new(compute_stats(selectdim(A, dims, inds))...) + end + + function MeanStdScaling(table; cols=_get_cols(table)) + _depwarn() + Tables.istable(table) || throw(MethodError(MeanStdScaling, table)) + columntable = Tables.columns(table) + data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)]) + return new(compute_stats(data)...) + end +end + +function _depwarn() + Base.depwarn( + "`MeanStdScaling(args...; kwargs...)` is deprecated. 
Use " * + "`scaling = StandardScaling(); fit!(scaling, args...; kwargs...)` instead", + :MeanStdScaling + ) + return nothing +end + +function _apply(A::AbstractArray, scaling::MeanStdScaling; inverse=false, eps=1e-3, kwargs...) + inverse && return scaling.μ .+ scaling.σ .* A + # Avoid division by 0 + # If std is 0 then data was uniform, so the scaled value would end up ≈ 0 + # Therefore the particular `eps` value should not matter much. + σ_safe = max(scaling.σ, eps) + return (A .- scaling.μ) ./ σ_safe +end + +# TODO: remove in v0.4 https://github.com/invenia/FeatureTransforms.jl/issues/82 +Base.@deprecate_binding is_transformable TestUtils.is_transformable diff --git a/src/fit.jl b/src/fit.jl new file mode 100644 index 0000000..cb1f801 --- /dev/null +++ b/src/fit.jl @@ -0,0 +1,20 @@ +""" + fit!(transform::Transform, data::AbstractArray; dims=:, inds=:) + fit!(transform::Transform, table; [cols]) + +Fit the transform to the given data. By default _all the data_ is considered. +This can be restricted to certain slices via the keyword arguments (see below). + +# `AbstractArray` keyword arguments +* `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims. +* `inds=:`: the indices to use in computing the statistics. Default uses all indices. + +# `Table` keyword arguments +* `cols`: the columns to use in computing the statistics. Default uses all columns. + +!!! note + If you want to transform your data consistently you should use the same `inds`, `dims`, + or `cols` keywords when calling `apply`. Otherwise, `apply` might rescale the wrong + data or throw an error. +""" +fit!(t::Transform, args...; kwargs...) = return t diff --git a/src/scaling.jl b/src/scaling.jl index 35e9a9d..53e7e3a 100644 --- a/src/scaling.jl +++ b/src/scaling.jl @@ -17,64 +17,50 @@ IdentityScaling(args...) = IdentityScaling() @inline _apply(x, ::IdentityScaling; kwargs...) 
= x """ - MeanStdScaling(μ, σ) <: AbstractScaling + StandardScaling <: AbstractScaling -Linearly scale the data by the statistical mean `μ` and standard deviation `σ`. -This is also known as standardization, or the Z score transform. +Transforms the data according to -# Keyword arguments to `apply` -* `inverse=true`: inverts the scaling (e.g. to reconstruct the unscaled data). -* `eps=1e-3`: used in place of all 0 values in `σ` before scaling (if `inverse=false`). -""" -struct MeanStdScaling <: AbstractScaling - μ::Real - σ::Real - - """ - MeanStdScaling(A::AbstractArray; dims=:, inds=:) -> MeanStdScaling - MeanStdScaling(table, [cols]) -> MeanStdScaling - - Construct a [`MeanStdScaling`](@ref) transform from the statistics of the given data. - By default _all the data_ is considered when computing the mean and standard deviation. - This can be restricted to certain slices via the keyword arguments (see below). + x -> (x - μ) / σ - Since `MeanStdScaling` is a stateful transform, i.e. the parameters depend on the data - it's given, you should define it independently before applying it so you can keep the - information for later use. For instance, if you want to invert the transform or apply it - to a test set. +where μ and σ are the mean and standard deviation of the training data. - # `AbstractArray` keyword arguments - * `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims. - * `inds=:`: the indices to use in computing the statistics. Default uses all indices. - - # `Table` keyword arguments - * `cols`: the columns to use in computing the statistics. Default uses all columns. +!!! note + `fit!(scaling, data)` needs to be called before the transform can be `apply`ed. + By default _all the data_ is considered when `fit!`ing the mean and standard deviation. +""" +mutable struct StandardScaling <: AbstractScaling + μ::Union{Real, Nothing} + σ::Union{Real, Nothing} - !!! 
note - If you want the `MeanStdScaling` to transform your data consistently you should use - the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply` - might rescale the wrong data or throw an error. - """ - function MeanStdScaling(A::AbstractArray; dims=:, inds=:) - dims == Colon() && return new(compute_stats(A)...) - return new(compute_stats(selectdim(A, dims, inds))...) - end + StandardScaling() = return new(nothing, nothing) +end - function MeanStdScaling(table; cols=_get_cols(table)) - Tables.istable(table) || throw(MethodError(MeanStdScaling, table)) - columntable = Tables.columns(table) - data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)]) - return new(compute_stats(data)...) - end +function fit!(ss::StandardScaling, args...; kwargs...) + ss.μ isa Nothing || throw(ErrorException("StandardScaling should not be refit.")) + ss.μ, ss.σ = _fit(ss, args...; kwargs...) + return ss end compute_stats(x) = (mean(x), std(x)) -function _apply(A::AbstractArray, scaling::MeanStdScaling; inverse=false, eps=1e-3, kwargs...) - inverse && return scaling.μ .+ scaling.σ .* A +function _fit(::StandardScaling, data::AbstractArray; dims=:, inds=:) + dims isa Colon && return compute_stats(data) + compute_stats(selectdim(data, dims, inds)) +end +function _fit(::StandardScaling, table; cols=_get_cols(table)) + Tables.istable(table) || throw(MethodError(StandardScaling, table)) + columntable = Tables.columns(table) + data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)]) + return compute_stats(data) +end + +function _apply(A::AbstractArray, ss::StandardScaling; inverse=false, eps=1e-3, kwargs...) + ss.μ isa Real || throw(ErrorException("`fit!` StandardScaling before applying.")) + inverse && return ss.μ .+ ss.σ .* A # Avoid division by 0 # If std is 0 then data was uniform, so the scaled value would end up ≈ 0 # Therefore the particular `eps` value should not matter much. 
- σ_safe = max(scaling.σ, eps) - return (A .- scaling.μ) ./ σ_safe + σ_safe = max(ss.σ, eps) + return (A .- ss.μ) ./ σ_safe end diff --git a/test/deprecated.jl b/test/deprecated.jl new file mode 100644 index 0000000..5ce8789 --- /dev/null +++ b/test/deprecated.jl @@ -0,0 +1,65 @@ +@testset "deprecated.jl" begin + + @testset "MeanStdScaling" begin + @testset "Constructor" begin + M = [0.0 -0.5 0.5; 0.0 1.0 2.0] + nt = (a = [0.0, -0.5, 0.5], b = [1.0, 0.0, 2.0]) + + @testset "simple" for x in (M, nt) + x_copy = deepcopy(x) + @test_deprecated scaling = MeanStdScaling(x) + @test cardinality(scaling) == OneToOne() + @test scaling isa Transform + @test x == x_copy # data is not mutated + # constructor uses all data by default + @test scaling.μ == 0.5 + @test scaling.σ ≈ 0.89443 atol=1e-5 + end + + @testset "use certain slices to compute statistics" begin + @testset "Array" begin + scaling = MeanStdScaling(M; dims=1, inds=[2]) + @test scaling.μ == 1.0 + @test scaling.σ == 1.0 + end + + @testset "Table" begin + scaling = MeanStdScaling(nt; cols=:a) + @test scaling.μ == 0.0 + @test scaling.σ == 0.5 + end + end + end + + @testset "Re-apply" begin + M = [0.0 -0.5 0.5; 0.0 1.0 2.0] + scaling = MeanStdScaling(M; dims=2) + new_M = [1.0 -2.0 -1.0; 0.5 0.0 0.5] + @test M !== new_M + # Expect scaling parameters to be fixed to the first data applied to + expected_reapply = [0.559017 -2.79508 -1.67705; 0.0 -0.55901 0.0] + @test FeatureTransforms.apply(new_M, scaling; dims=2) ≈ expected_reapply atol=1e-5 + end + + @testset "Inverse" begin + M = [0.0 -0.5 0.5; 0.0 1.0 2.0] + M_expected = [-0.559017 -1.11803 0.0; -0.559017 0.559017 1.67705] + scaling = MeanStdScaling(M) + transformed = FeatureTransforms.apply(M, scaling) + + @test transformed ≈ M_expected atol=1e-5 + @test FeatureTransforms.apply(transformed, scaling; inverse=true) ≈ M atol=1e-5 + end + + @testset "Zero std" begin + x = ones(Float64, 3) + expected = zeros(Float64, 3) + + scaling = MeanStdScaling(x) + + @test 
FeatureTransforms.apply(x, scaling) == expected # default `eps` + @test FeatureTransforms.apply(x, scaling; eps=1) == expected + @test all(isnan.(FeatureTransforms.apply(x, scaling; eps=0))) # 0/0 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index eaa447c..61a1d55 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using Documenter: doctest using FeatureTransforms using FeatureTransforms: _periodic using FeatureTransforms: cardinality, OneToOne, OneToMany, ManyToOne, ManyToMany +using FeatureTransforms: fit! using FeatureTransforms.TestUtils using Tables: columntable, isrowtable, istable, rowtable using Test @@ -31,4 +32,6 @@ using TimeZones include("types/matrix.jl") include("types/cube.jl") include("types/xarray.jl") + + include("deprecated.jl") end diff --git a/test/scaling.jl b/test/scaling.jl index 2846b27..6ad603c 100644 --- a/test/scaling.jl +++ b/test/scaling.jl @@ -11,16 +11,25 @@ end end - @testset "MeanStdScaling" begin - @testset "Constructor" begin + @testset "StandardScaling" begin + @testset "constructor" begin + ss = StandardScaling() + @test (ss.μ, ss.σ) == (nothing, nothing) + @test cardinality(ss) == OneToOne() + @test ss isa Transform + + @test_throws MethodError StandardScaling(0.0, 1.0, false) + end + + @testset "fit!" 
begin M = [0.0 -0.5 0.5; 0.0 1.0 2.0] nt = (a = [0.0, -0.5, 0.5], b = [1.0, 0.0, 2.0]) @testset "simple" for x in (M, nt) + scaling = StandardScaling() x_copy = deepcopy(x) - scaling = MeanStdScaling(x) - @test cardinality(scaling) == OneToOne() - @test scaling isa Transform + + fit!(scaling, x_copy) @test x == x_copy # data is not mutated # constructor uses all data by default @test scaling.μ == 0.5 @@ -29,22 +38,32 @@ @testset "use certain slices to compute statistics" begin @testset "Array" begin - scaling = MeanStdScaling(M; dims=1, inds=[2]) + scaling = StandardScaling() + fit!(scaling, M; dims=1, inds=[2]) @test scaling.μ == 1.0 @test scaling.σ == 1.0 end @testset "Table" begin - scaling = MeanStdScaling(nt; cols=:a) + scaling = StandardScaling() + fit!(scaling, nt; cols=:a) @test scaling.μ == 0.0 @test scaling.σ == 0.5 end end + + @testset "refit" begin + x = rand(10) + scaling = StandardScaling() + fit!(scaling, x) + @test_throws ErrorException fit!(scaling, x) + end end @testset "Re-apply" begin M = [0.0 -0.5 0.5; 0.0 1.0 2.0] - scaling = MeanStdScaling(M; dims=2) + scaling = StandardScaling() + fit!(scaling, M; dims=2) new_M = [1.0 -2.0 -1.0; 0.5 0.0 0.5] @test M !== new_M # Expect scaling parameters to be fixed to the first data applied to @@ -55,7 +74,8 @@ @testset "Inverse" begin M = [0.0 -0.5 0.5; 0.0 1.0 2.0] M_expected = [-0.559017 -1.11803 0.0; -0.559017 0.559017 1.67705] - scaling = MeanStdScaling(M) + scaling = StandardScaling() + fit!(scaling, M) transformed = FeatureTransforms.apply(M, scaling) @test transformed ≈ M_expected atol=1e-5 @@ -66,7 +86,8 @@ x = ones(Float64, 3) expected = zeros(Float64, 3) - scaling = MeanStdScaling(x) + scaling = StandardScaling() + fit!(scaling, x) @test FeatureTransforms.apply(x, scaling) == expected # default `eps` @test FeatureTransforms.apply(x, scaling; eps=1) == expected diff --git a/test/test_utils.jl b/test/test_utils.jl index 714b9e8..cb38397 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ 
-16,6 +16,7 @@ result = FeatureTransforms.apply(M, t) @test result == expected @test result isa KeyedArray + @test deepcopy(t) == fit!(t) end @testset "FakeOneToManyTransform" begin @@ -34,6 +35,7 @@ result = FeatureTransforms.apply(M, t) @test result == expected @test result isa KeyedArray + @test deepcopy(t) == fit!(t) end @testset "FakeManyToOneTransform" begin @@ -52,6 +54,7 @@ result = FeatureTransforms.apply(M, t; dims=:b) @test result == expected @test result isa KeyedArray + @test deepcopy(t) == fit!(t) end @testset "FakeManyToManyTransform" begin @@ -70,6 +73,7 @@ result = FeatureTransforms.apply(M, t) @test result == expected @test result isa KeyedArray + @test deepcopy(t) == fit!(t) end