Skip to content

Commit

Permalink
Expose corrected kwarg as alternative
Browse files Browse the repository at this point in the history
  • Loading branch information
bencottier committed May 12, 2021
1 parent e402c78 commit 5ce668d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
19 changes: 10 additions & 9 deletions src/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct MeanStdScaling <: AbstractScaling
σ::Real

"""
MeanStdScaling(A::AbstractArray; dims=:, inds=:) -> MeanStdScaling
MeanStdScaling(table, [cols]) -> MeanStdScaling
MeanStdScaling(A::AbstractArray; dims=:, inds=:, corrected=true) -> MeanStdScaling
MeanStdScaling(table, [cols], corrected=true) -> MeanStdScaling
Construct a [`MeanStdScaling`](@ref) transform from the statistics of the given data.
By default _all the data_ is considered when computing the mean and standard deviation.
Expand All @@ -46,30 +46,31 @@ struct MeanStdScaling <: AbstractScaling
# `AbstractArray` keyword arguments
* `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims.
* `inds=:`: the indices to use in computing the statistics. Default uses all indices.
* `corrected=true`: passed to `Statistics.std`.
# `Table` keyword arguments
* `cols`: the columns to use in computing the statistics. Default uses all columns.
* `corrected=true`: passed to `Statistics.std`.
!!! note
If you want the `MeanStdScaling` to transform your data consistently you should use
the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply`
might rescale the wrong data or throw an error.
"""
function MeanStdScaling(A::AbstractArray; dims=:, inds=:)
dims == Colon() && return new(compute_stats(A)...)
return new(compute_stats(selectdim(A, dims, inds))...)
function MeanStdScaling(A::AbstractArray; dims=:, inds=:, corrected=true)
dims == Colon() && return new(compute_stats(A; corrected=corrected)...)
return new(compute_stats(selectdim(A, dims, inds); corrected=corrected)...)
end

function MeanStdScaling(table; cols=_get_cols(table))
function MeanStdScaling(table; cols=_get_cols(table), corrected=true)
Tables.istable(table) || throw(MethodError(MeanStdScaling, table))
columntable = Tables.columns(table)
data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)])
return new(compute_stats(data)...)
return new(compute_stats(data; corrected=corrected)...)
end
end

# Set std to 0 using corrected=false if x is a singleton
compute_stats(x) = (mean(x), std(x; corrected=(length(x) != 1)))
compute_stats(x; corrected) = (mean(x), std(x; corrected=corrected))

function _apply(A::AbstractArray, scaling::MeanStdScaling; inverse=false, eps=1e-3, kwargs...)
inverse && return scaling.μ .+ scaling.σ .* A
Expand Down
26 changes: 20 additions & 6 deletions test/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,26 @@
@test scaling.σ == 0.5
end
end

@testset "std correction" begin
@testset "singleton" begin
x = [2.]

scaling = MeanStdScaling(x)
@test scaling.μ == 2.
@test isnan(scaling.σ)

scaling = MeanStdScaling(x; corrected=false)
@test scaling.μ == 2.
@test scaling.σ == 0.
end

@testset "Array" begin
scaling = MeanStdScaling(M; corrected=false)
@test scaling.μ == 0.5
@test scaling.σ 0.81650 atol=1e-5
end
end
end

@testset "Vector" begin
Expand Down Expand Up @@ -312,12 +332,6 @@
scaling = MeanStdScaling(x)
@test FeatureTransforms.apply_append(x, scaling, append_dim=1) == vcat(x, expected)
end

@testset "singleton" begin
x = [2.]
scaling = MeanStdScaling(x)
@test FeatureTransforms.apply(x, scaling) == [0.]
end
end

@testset "Matrix" begin
Expand Down

0 comments on commit 5ce668d

Please sign in to comment.