Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StandardScaling with separate constructor and fit methods to replace MeanStdScaling #107

Merged
merged 16 commits into from
May 17, 2022
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "FeatureTransforms"
uuid = "8fd68953-04b8-4117-ac19-158bf6de9782"
authors = ["Invenia Technical Computing Corporation"]
version = "0.3.11"
version = "0.3.12"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Expand All @@ -16,6 +17,7 @@ AxisKeys = "0.1"
DataFrames = "0.22, 1"
Documenter = "0.26"
NamedDims = "0.2.32"
StatsBase = "0.33"
Tables = "1.3"
julia = "1.3"

Expand Down
6 changes: 5 additions & 1 deletion src/FeatureTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ module FeatureTransforms
using Dates: TimeType, Period, Day, hour
using NamedDims: dim
using Statistics: mean, std
using StatsBase
using Tables

export Transform, transform, transform!
export HoD, LinearCombination, OneHotEncoding, Periodic, Power
export AbstractScaling, IdentityScaling, MeanStdScaling
export AbstractScaling, IdentityScaling, MeanStdScaling, StandardScaling
export LogTransform, InverseHyperbolicSine

include("utils.jl")
include("traits.jl")
include("transform.jl")
include("apply.jl")
include("fit.jl")

# Transform implementations
include("linear_combination.jl")
Expand All @@ -26,6 +28,8 @@ include("temporal.jl")

include("test_utils.jl")

include("deprecated.jl")

# TODO: remove in v0.4 https://github.com/invenia/FeatureTransforms.jl/issues/82
Base.@deprecate_binding is_transformable TestUtils.is_transformable
mzgubic marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
73 changes: 73 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
MeanStdScaling(μ, σ) <: AbstractScaling

Linearly scale the data by the statistical mean `μ` and standard deviation `σ`.
This is also known as standardization, or the Z score transform.

# Keyword arguments to `apply`
* `inverse=true`: inverts the scaling (e.g. to reconstruct the unscaled data).
* `eps=1e-3`: used in place of all 0 values in `σ` before scaling (if `inverse=false`).
"""
struct MeanStdScaling <: AbstractScaling
μ::Real
σ::Real

"""
MeanStdScaling(A::AbstractArray; dims=:, inds=:) -> MeanStdScaling
MeanStdScaling(table, [cols]) -> MeanStdScaling

Construct a [`MeanStdScaling`](@ref) transform from the statistics of the given data.
By default _all the data_ is considered when computing the mean and standard deviation.
This can be restricted to certain slices via the keyword arguments (see below).

Since `MeanStdScaling` is a stateful transform, i.e. the parameters depend on the data
it's given, you should define it independently before applying it so you can keep the
information for later use. For instance, if you want to invert the transform or apply it
to a test set.

# `AbstractArray` keyword arguments
* `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims.
* `inds=:`: the indices to use in computing the statistics. Default uses all indices.

# `Table` keyword arguments
* `cols`: the columns to use in computing the statistics. Default uses all columns.

!!! note
If you want the `MeanStdScaling` to transform your data consistently you should use
the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply`
might rescale the wrong data or throw an error.
"""
function MeanStdScaling(A::AbstractArray; dims=:, inds=:)
_depwarn()
dims == Colon() && return new(compute_stats(A)...)
return new(compute_stats(selectdim(A, dims, inds))...)
end

function MeanStdScaling(table; cols=_get_cols(table))
_depwarn()
Tables.istable(table) || throw(MethodError(MeanStdScaling, table))
columntable = Tables.columns(table)
data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)])
return new(compute_stats(data)...)
end
end

function _depwarn()
Base.depwarn(
"`MeanStdScaling(args...; kwargs...)` is deprecated. Use " *
"`ss = StandardScaling(); fit!(scaling, args...; kwargs...)` instead",
:MeanStdScaling
)
return nothing
end

compute_stats(x) = (mean(x), std(x))

function _apply(A::AbstractArray, scaling::MeanStdScaling; inverse=false, eps=1e-3, kwargs...)
inverse && return scaling.μ .+ scaling.σ .* A
# Avoid division by 0
# If std is 0 then data was uniform, so the scaled value would end up ≈ 0
# Therefore the particular `eps` value should not matter much.
σ_safe = max(scaling.σ, eps)
return (A .- scaling.μ) ./ σ_safe
end
1 change: 1 addition & 0 deletions src/fit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
StatsBase.fit!(t::Transform, args...; kwargs...) = return t
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
103 changes: 56 additions & 47 deletions src/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,73 @@ IdentityScaling(args...) = IdentityScaling()
@inline _apply(x, ::IdentityScaling; kwargs...) = x

"""
MeanStdScaling(μ, σ) <: AbstractScaling
StandardScaling <: AbstractScaling

Linearly scale the data by the statistical mean `μ` and standard deviation `σ`.
This is also known as standardization, or the Z score transform.
Transforms the data according to

# Keyword arguments to `apply`
* `inverse=true`: inverts the scaling (e.g. to reconstruct the unscaled data).
* `eps=1e-3`: used in place of all 0 values in `σ` before scaling (if `inverse=false`).
x -> (x - μ) / σ

where μ and σ are the mean and standard deviation of the training data.

!!! note
`fit!(scaling, data)` needs to be called before the transform can be `apply`ed.
By default _all the data_ is considered when `fit!`ing the mean and standard deviation.
"""
struct MeanStdScaling <: AbstractScaling
mutable struct StandardScaling <: AbstractScaling
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
μ::Real
σ::Real
fitted::Bool
mzgubic marked this conversation as resolved.
Show resolved Hide resolved

"""
MeanStdScaling(A::AbstractArray; dims=:, inds=:) -> MeanStdScaling
MeanStdScaling(table, [cols]) -> MeanStdScaling

Construct a [`MeanStdScaling`](@ref) transform from the statistics of the given data.
By default _all the data_ is considered when computing the mean and standard deviation.
This can be restricted to certain slices via the keyword arguments (see below).

Since `MeanStdScaling` is a stateful transform, i.e. the parameters depend on the data
it's given, you should define it independently before applying it so you can keep the
information for later use. For instance, if you want to invert the transform or apply it
to a test set.

# `AbstractArray` keyword arguments
* `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims.
* `inds=:`: the indices to use in computing the statistics. Default uses all indices.

# `Table` keyword arguments
* `cols`: the columns to use in computing the statistics. Default uses all columns.

!!! note
If you want the `MeanStdScaling` to transform your data consistently you should use
the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply`
might rescale the wrong data or throw an error.
"""
function MeanStdScaling(A::AbstractArray; dims=:, inds=:)
dims == Colon() && return new(compute_stats(A)...)
return new(compute_stats(selectdim(A, dims, inds))...)
end
StandardScaling() = return new(0.0, 1.0, false)
end

function MeanStdScaling(table; cols=_get_cols(table))
Tables.istable(table) || throw(MethodError(MeanStdScaling, table))
columntable = Tables.columns(table)
data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)])
return new(compute_stats(data)...)
end
"""
fit!(scaling::StandardScaling, data::AbstractArray; dims=:, inds=:)
fit!(scaling::StandardScaling, table, [cols])

Fit the [`StandardScaling`](@ref) transform to the given data. By default _all the data_
is considered when computing the mean and standard deviation.
This can be restricted to certain slices via the keyword arguments (see below).

# `AbstractArray` keyword arguments
* `dims=:`: the dimension along which to take the `inds` slices. Default uses all dims.
* `inds=:`: the indices to use in computing the statistics. Default uses all indices.

# `Table` keyword arguments
* `cols`: the columns to use in computing the statistics. Default uses all columns.

!!! note
If you want the `StandardScaling` to transform your data consistently you should use
the same `inds`, `dims`, or `cols` keywords when calling `apply`. Otherwise, `apply`
might rescale the wrong data or throw an error.
"""
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
function StatsBase.fit!(ss::StandardScaling, args...; kwargs...)
ss.fitted === true && @warn("StandardScaling is being refit, Y?")
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
μ, σ = _fit(ss, args...; kwargs...)
ss.μ, ss.σ, ss.fitted = μ, σ, true
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
return ss
end

compute_stats(x) = (mean(x), std(x))
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
function _fit(::StandardScaling, data::AbstractArray; dims=:, inds=:)
return if dims isa Colon
compute_stats(data)
else
compute_stats(selectdim(data, dims, inds))
end
mzgubic marked this conversation as resolved.
Show resolved Hide resolved
end
function _fit(::StandardScaling, table; cols=_get_cols(table))
Tables.istable(table) || throw(MethodError(StandardScaling, table))
columntable = Tables.columns(table)
data = reduce(vcat, [getproperty(columntable, c) for c in _to_vec(cols)])
return compute_stats(data)
end

function _apply(A::AbstractArray, scaling::MeanStdScaling; inverse=false, eps=1e-3, kwargs...)
inverse && return scaling.μ .+ scaling.σ .* A
function _apply(A::AbstractArray, ss::StandardScaling; inverse=false, eps=1e-3, kwargs...)
ss.fitted === true || throw(ErrorException("`fit!` StandardScaling before applying."))
inverse && return ss.μ .+ ss.σ .* A
# Avoid division by 0
# If std is 0 then data was uniform, so the scaled value would end up ≈ 0
# Therefore the particular `eps` value should not matter much.
σ_safe = max(scaling.σ, eps)
return (A .- scaling.μ) ./ σ_safe
σ_safe = max(ss.σ, eps)
return (A .- ss.μ) ./ σ_safe
end
65 changes: 65 additions & 0 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
@testset "deprecated.jl" begin

@testset "MeanStdScaling" begin
@testset "Constructor" begin
M = [0.0 -0.5 0.5; 0.0 1.0 2.0]
nt = (a = [0.0, -0.5, 0.5], b = [1.0, 0.0, 2.0])

@testset "simple" for x in (M, nt)
x_copy = deepcopy(x)
@test_deprecated scaling = MeanStdScaling(x)
@test cardinality(scaling) == OneToOne()
@test scaling isa Transform
@test x == x_copy # data is not mutated
# constructor uses all data by default
@test scaling.μ == 0.5
@test scaling.σ ≈ 0.89443 atol=1e-5
end

@testset "use certain slices to compute statistics" begin
@testset "Array" begin
scaling = MeanStdScaling(M; dims=1, inds=[2])
@test scaling.μ == 1.0
@test scaling.σ == 1.0
end

@testset "Table" begin
scaling = MeanStdScaling(nt; cols=:a)
@test scaling.μ == 0.0
@test scaling.σ == 0.5
end
end
end

@testset "Re-apply" begin
M = [0.0 -0.5 0.5; 0.0 1.0 2.0]
scaling = MeanStdScaling(M; dims=2)
new_M = [1.0 -2.0 -1.0; 0.5 0.0 0.5]
@test M !== new_M
# Expect scaling parameters to be fixed to the first data applied to
expected_reapply = [0.559017 -2.79508 -1.67705; 0.0 -0.55901 0.0]
@test FeatureTransforms.apply(new_M, scaling; dims=2) ≈ expected_reapply atol=1e-5
end

@testset "Inverse" begin
M = [0.0 -0.5 0.5; 0.0 1.0 2.0]
M_expected = [-0.559017 -1.11803 0.0; -0.559017 0.559017 1.67705]
scaling = MeanStdScaling(M)
transformed = FeatureTransforms.apply(M, scaling)

@test transformed ≈ M_expected atol=1e-5
@test FeatureTransforms.apply(transformed, scaling; inverse=true) ≈ M atol=1e-5
end

@testset "Zero std" begin
x = ones(Float64, 3)
expected = zeros(Float64, 3)

scaling = MeanStdScaling(x)

@test FeatureTransforms.apply(x, scaling) == expected # default `eps`
@test FeatureTransforms.apply(x, scaling; eps=1) == expected
@test all(isnan.(FeatureTransforms.apply(x, scaling; eps=0))) # 0/0
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using FeatureTransforms
using FeatureTransforms: _periodic
using FeatureTransforms: cardinality, OneToOne, OneToMany, ManyToOne, ManyToMany
using FeatureTransforms.TestUtils
using StatsBase
using Tables: columntable, isrowtable, istable, rowtable
using Test
using TimeZones
Expand All @@ -31,4 +32,6 @@ using TimeZones
include("types/matrix.jl")
include("types/cube.jl")
include("types/xarray.jl")

include("deprecated.jl")
end
Loading