Skip to content

Commit

Permalink
Merge pull request #87 from invenia/rf/declaremissings-not-imputor
Browse files Browse the repository at this point in the history
DeclareMissings is not a real Imputor
  • Loading branch information
rofinn authored Nov 2, 2020
2 parents 2de3f7c + b276323 commit 3ef5852
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 83 deletions.
1 change: 1 addition & 0 deletions src/Impute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using LinearAlgebra
using LinearAlgebra: Diagonal

include("utils.jl")
include("declaremissings.jl")
include("imputors.jl")
include("filter.jl")
include("validators.jl")
Expand Down
6 changes: 5 additions & 1 deletion src/chain.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const Transform = Union{Validator, Filter, Imputor}
# This should maybe be a supertype at some point?
const Transform = Union{Validator, DeclareMissings, Filter, Imputor}

"""
Chain{T<:Tuple{Vararg{Transform}}} <: Function
Expand Down Expand Up @@ -69,6 +70,9 @@ function (C::Chain)(data; kwargs...)
if isa(t, Validator)
# Validators just return the input
validate(X, t; kwargs...)
elseif isa(t, DeclareMissings)
# DeclareMissings isn't guaranteed to work in-place
X = apply(X, t)
elseif isa(t, Filter)
# Filtering doesn't always work in-place
X = apply(X, t; kwargs...)
Expand Down
31 changes: 12 additions & 19 deletions src/imputors/declaremissings.jl → src/declaremissings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ a `missing`.
# Example
```jldoctest
julia> using Impute: DeclareMissings, impute
julia> using Impute: DeclareMissings, apply
julia> M = [1.0 2.0 -9999.0 NaN 5.0; 1.1 2.2 3.3 0.0 5.5]
2×5 Array{Float64,2}:
1.0 2.0 -9999.0 NaN 5.0
1.1 2.2 3.3 0.0 5.5
julia> impute(M, DeclareMissings(; values=(NaN, -9999.0, 0.0)))
julia> apply(M, DeclareMissings(; values=(NaN, -9999.0, 0.0)))
2×5 Array{Union{Missing, Float64},2}:
1.0 2.0 missing missing 5.0
1.1 2.2 3.3 missing 5.5
```
"""
struct DeclareMissings{T<:Tuple} <: Imputor
struct DeclareMissings{T<:Tuple}
values::T
end

Expand All @@ -34,25 +34,27 @@ function DeclareMissings(; values)
return DeclareMissings{typeof(T)}(T)
end

apply!(data::AbstractArray{Missing}, imp::DeclareMissings) = data

# Primary definition just calls `replace!`
function _impute!(data::AbstractArray{Union{T, Missing}}, imp::DeclareMissings) where T
function apply!(data::AbstractArray{Union{T, Missing}}, imp::DeclareMissings) where T
# Reduce the possible set of values to those that could actually be found in the data
# Useful, if we declare a `Replace` imputor that should be applied to multiple datasets.
Base.replace!(data, (v => missing for v in imp.values if v isa T)...)
end

# Most of the time the in-place methods won't work because we need to change the
# eltype with allowmissing
impute(data::AbstractArray, imp::DeclareMissings) = _impute!(allowmissing(data), imp)
apply(data::AbstractArray, imp::DeclareMissings) = apply!(allowmissing(data), imp)

# Custom implementation of a non-mutating impute for tables
function impute(table, imp::DeclareMissings)
istable(table) || throw(MethodError(impute, (table, imp)))
function apply(table, imp::DeclareMissings)
istable(table) || throw(MethodError(apply, (table, imp)))

ctable = Tables.columns(table)

cnames = Tuple(propertynames(ctable))
cdata = Tuple(impute(getproperty(ctable, cname), imp) for cname in cnames)
cdata = Tuple(apply(getproperty(ctable, cname), imp) for cname in cnames)
# Reconstruct as a ColumnTable
result = NamedTuple{cnames}(cdata)

Expand All @@ -65,16 +67,7 @@ function impute(table, imp::DeclareMissings)
end

# Specialcase for rowtable
function impute(data::T, imp::DeclareMissings) where T <: AbstractVector{<:NamedTuple}
function apply(data::T, imp::DeclareMissings) where T <: AbstractVector{<:NamedTuple}
# We use columntable here so that we don't call `materialize` more often than needed.
return materializer(data)(impute(Tables.columntable(data), imp))
end

# Awkward imputor overrides necessary because we intercepted the higher level
# `impute` calls
_impute!(data::AbstractArray{Missing}, imp::DeclareMissings) = data

# Skip custom dims stuff cause it isn't necessary here.
function impute!(data::AbstractMatrix{Union{T, Missing}}, imp::DeclareMissings) where {T}
return _impute!(data, imp)
return materializer(data)(apply(Tables.columntable(data), imp))
end
4 changes: 3 additions & 1 deletion src/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ const global imputation_methods = (
nocb = NOCB,
replace = Replace,
srs = SRS,
declaremissings = DeclareMissings,
substitute = Substitute,
svd = SVD,
knn = KNN,
Expand Down Expand Up @@ -83,6 +82,9 @@ for (func, type) in pairs(imputation_methods)
end
end

declaremissings(data; kwargs...) = apply(data, DeclareMissings(; kwargs...))
declaremissings!(data; kwargs...) = apply!(data, DeclareMissings(; kwargs...))

# Provide a specific functional API for Impute.Filter.
filter(data; kwargs...) = apply(data, Filter(); kwargs...)
filter!(data; kwargs...) = apply!(data, Filter(); kwargs...)
Expand Down
1 change: 0 additions & 1 deletion src/imputors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ files = [
"nocb.jl",
"replace.jl",
"srs.jl",
"declaremissings.jl",
"substitute.jl",
"svd.jl",
]
Expand Down
1 change: 1 addition & 0 deletions test/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
# Filter out colunns with more than 400 missing values, Fill with 0, and check that
# everything was replaced
C = Chain(
Impute.DeclareMissings(; values=(NaN, Inf, -Inf)),
Impute.Filter(c -> count(ismissing, c) < 400),
Impute.Replace(; values=0.0),
Impute.Threshold(),
Expand Down
82 changes: 22 additions & 60 deletions test/imputors/declaremissings.jl → test/declaremissings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,39 @@
a = collect(1.0:1.0:20.0)
a[[2, 3, 7]] .= [NaN, 0.0, NaN]

result = impute(a, imp)
result = apply(a, imp)
@test eltype(result) == Union{Float64, Missing}
@test all(ismissing, result[[2, 3, 7]])

# In-place operation don't work when the source array doesn't allow missings.
b = copy(a)
result2 = impute!(b, imp)
@test eltype(result2) == Float64
@test isequal(result2[[2, 3, 7]], [NaN, 0.0, NaN])
@test_throws MethodError apply!(b, imp)
end

@testset "allowmissing" begin
a = allowmissing(collect(1.0:1.0:20.0))
a[[2, 3, 7]] .= [NaN, 0.0, NaN]

result = impute(a, imp)
result = apply(a, imp)
@test eltype(result) == Union{Float64, Missing}
@test all(ismissing, result[[2, 3, 7]])

# In-place operation don't work when the source array doesn't allow missings.
# In-place operation work when the source array allows missings.
b = copy(a)
result2 = impute!(b, imp)
result2 = apply!(b, imp)
@test eltype(result2) == Union{Float64, Missing}
@test all(ismissing, result2[[2, 3, 7]])

c = copy(a)
result3 = Impute.declaremissings!(c; values=values)
@test eltype(result3) == Union{Float64, Missing}
@test all(ismissing, result3[[2, 3, 7]])
end

@testset "All missing" begin
# Test having only missing data
c = fill(missing, 10)
@test isequal(impute(c, imp), c)
@test isequal(apply(c, imp), c)
end
end

Expand All @@ -49,37 +52,35 @@
a[[2, 3, 7]] .= [NaN, 0.0, NaN]
m = collect(reshape(a, 5, 4))

result = impute(m, imp)
result = apply(m, imp)
@test eltype(result) == Union{Float64, Missing}
@test all(ismissing, result[[2, 3, 7]])

# In-place operation don't work when the source array doesn't allow missings.
n = copy(m)
result2 = impute!(n, imp)
@test eltype(result2) == Float64
@test isequal(result2[[2, 3, 7]], [NaN, 0.0, NaN])
@test_throws MethodError apply!(n, imp)
end

@testset "allowmissing" begin
a = allowmissing(collect(1.0:1.0:20.0))
a[[2, 3, 7]] .= [NaN, 0.0, NaN]
m = collect(reshape(a, 5, 4))

result = impute(m, imp)
result = apply(m, imp)
@test eltype(result) == Union{Float64, Missing}
@test all(ismissing, result[[2, 3, 7]])

# In-place operation don't work when the source array doesn't allow missings.
n = copy(m)
result2 = impute!(n, imp)
result2 = apply!(n, imp)
@test eltype(result2) == Union{Float64, Missing}
@test all(ismissing, result2[[2, 3, 7]])
end

@testset "All missing" begin
# Test having only missing data
c = fill(missing, 5, 4)
@test isequal(impute(c, imp), c)
@test isequal(apply(c, imp), c)
end
end
@testset "Tables" begin
Expand All @@ -106,21 +107,8 @@
:desc => ["foo", "bar", missing],
)

@testset "disallowmissing" begin
result = impute(table, imp)
@test isequal(result, expected)

result2 = impute!(deepcopy(table), imp)
@test !isequal(result2, expected)
end

@testset "allowmissing" begin
result = impute(mtable, imp)
@test isequal(result, expected)

result2 = impute!(deepcopy(mtable), imp)
@test isequal(result2, expected)
end
result = apply(table, imp)
@test isequal(result, expected)
end

@testset "Column Table" begin
Expand All @@ -146,21 +134,8 @@
desc = ["foo", "bar", missing],
)

@testset "disallowmissing" begin
result = impute(table, imp)
@test isequal(result, expected)

result2 = impute!(deepcopy(table), imp)
@test !isequal(result2, expected)
end

@testset "allowmissing" begin
result = impute(mtable, imp)
@test isequal(result, expected)

result2 = impute!(deepcopy(mtable), imp)
@test isequal(result2, expected)
end
result = apply(table, imp)
@test isequal(result, expected)
end

@testset "Row Table" begin
Expand All @@ -186,21 +161,8 @@
desc = ["foo", "bar", missing],
))

@testset "disallowmissing" begin
result = impute(table, imp)
@test isequal(result, expected)

result2 = impute!(deepcopy(table), imp)
@test !isequal(result2, expected)
end

@testset "allowmissing" begin
result = impute(mtable, imp)
@test isequal(result, expected)

result2 = impute!(deepcopy(mtable), imp)
@test isequal(result2, expected)
end
result = apply(table, imp)
@test isequal(result, expected)
end
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ using Impute:
Threshold,
ThresholdError,
apply,
apply!,
impute,
impute!,
interp,
Expand All @@ -48,6 +49,7 @@ using Impute:
include("testutils.jl")

include("validators.jl")
include("declaremissings.jl")
include("chain.jl")
include("data.jl")
include("deprecated.jl")
Expand All @@ -58,7 +60,6 @@ using Impute:
include("imputors/nocb.jl")
include("imputors/replace.jl")
include("imputors/srs.jl")
include("imputors/declaremissings.jl")
include("imputors/substitute.jl")
include("imputors/svd.jl")
include("utils.jl")
Expand Down

0 comments on commit 3ef5852

Please sign in to comment.