diff --git a/Project.toml b/Project.toml index 5dec6c4..554f125 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FeatureTransforms" uuid = "8fd68953-04b8-4117-ac19-158bf6de9782" authors = ["Invenia Technical Computing Corporation"] -version = "0.3.6" +version = "0.3.7" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/src/scaling.jl b/src/scaling.jl index 3fd4b7b..5cdfe06 100644 --- a/src/scaling.jl +++ b/src/scaling.jl @@ -31,8 +31,8 @@ struct MeanStdScaling <: AbstractScaling σ::Real """ - MeanStdScaling(A::AbstractArray; dims=:, inds=:) -> MeanStdScaling - MeanStdScaling(table, [cols]) -> MeanStdScaling + MeanStdScaling(A::AbstractArray; dims=:, inds=:, corrected=true) -> MeanStdScaling + MeanStdScaling(table, [cols], corrected=true) -> MeanStdScaling Construct a [`MeanStdScaling`](@ref) transform from the statistics of the given data. By default _all the data_ is considered when computing the mean and standard deviation. @@ -46,29 +46,31 @@ struct MeanStdScaling <: AbstractScaling # `AbstractArray` keyword arguments * `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims. * `inds=:`: the indices to use in computing the statistics. Default uses all indices. + * `corrected=true`: passed to `Statistics.std`. # `Table` keyword arguments * `cols`: the columns to use in computing the statistics. Default uses all columns. + * `corrected=true`: passed to `Statistics.std`. !!! note If you want the `MeanStdScaling` to transform your data consistently you should use the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply` might rescale the wrong data or throw an error. """ - function MeanStdScaling(A::AbstractArray; dims=:, inds=:) - dims == Colon() && return new(compute_stats(A)...) - return new(compute_stats(selectdim(A, dims, inds))...) + function MeanStdScaling(A::AbstractArray; dims=:, inds=:, corrected=true) + dims == Colon() && return new(compute_stats(A; corrected=corrected)...) + return new(compute_stats(selectdim(A, dims, inds); corrected=corrected)...) end - function MeanStdScaling(table; cols=_get_cols(table)) + function MeanStdScaling(table; cols=_get_cols(table), corrected=true) Tables.istable(table) || throw(MethodError(MeanStdScaling, table)) columntable = Tables.columns(table) data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)]) - return new(compute_stats(data)...) + return new(compute_stats(data; corrected=corrected)...) end end -compute_stats(x) = (mean(x), std(x)) +compute_stats(x; corrected) = (mean(x), std(x; corrected=corrected)) function _apply(A::AbstractArray, scaling::MeanStdScaling; inverse=false, eps=1e-3, kwargs...) inverse && return scaling.μ .+ scaling.σ .* A diff --git a/test/scaling.jl b/test/scaling.jl index 7033413..afa762d 100644 --- a/test/scaling.jl +++ b/test/scaling.jl @@ -255,6 +255,26 @@ @test scaling.σ == 0.5 end end + + @testset "std correction" begin + @testset "singleton" begin + x = [2.] + + scaling = MeanStdScaling(x) + @test scaling.μ == 2. + @test isnan(scaling.σ) + + scaling = MeanStdScaling(x; corrected=false) + @test scaling.μ == 2. + @test scaling.σ == 0. + end + + @testset "Array" begin + scaling = MeanStdScaling(M; corrected=false) + @test scaling.μ == 0.5 + @test scaling.σ ≈ 0.81650 atol=1e-5 + end + end end @testset "Vector" begin