Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kNN Imputation #54

Merged
merged 1 commit into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ authors = ["Invenia Technical Computing"]
version = "0.4.0"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
rofinn marked this conversation as resolved.
Show resolved Hide resolved
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
18 changes: 18 additions & 0 deletions src/Impute.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module Impute

using Distances
using IterTools
using Missings
using NearestNeighbors
using Random
using Statistics
using StatsBase
Expand Down Expand Up @@ -67,6 +70,7 @@ const global imputation_methods = (
nocb = NOCB,
srs = SRS,
svd = SVD,
knn = KNN,
)

include("deprecated.jl")
Expand Down Expand Up @@ -333,4 +337,18 @@ Utility method for `impute(data, :svd; limit=limit)`
"""
svd(data::AbstractMatrix; limit=1.0) = impute(data, :svd; limit=limit)

"""
    knn!(data::AbstractMatrix; limit=1.0)

Shorthand that forwards `data` and the `limit` keyword to
`impute!(data, :knn; limit=limit)` and returns whatever that call returns.
"""
function knn!(data::AbstractMatrix; limit=1.0)
    return impute!(data, :knn; limit=limit)
end

"""
    knn(data::AbstractMatrix; limit=1.0)

Shorthand that forwards `data` and the `limit` keyword to
`impute(data, :knn; limit=limit)` and returns whatever that call returns.
"""
function knn(data::AbstractMatrix; limit=1.0)
    return impute(data, :knn; limit=limit)
end

end # module
2 changes: 1 addition & 1 deletion src/imputors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,6 @@ function impute!(table, imp::Imputor)
return table
end

for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl", "svd.jl")
for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl", "svd.jl", "knn.jl")
include(joinpath("imputors", file))
end
75 changes: 75 additions & 0 deletions src/imputors/knn.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
KNN <: Imputor

Imputation using k-Nearest Neighbor algorithm.

# Keyword Arguments
* `k::Int`: number of nearest neighbors
* `dist::MinkowskiMetric`: distance metric supported by `NearestNeighbors.jl` (Euclidean, Chebyshev, Minkowski and Cityblock)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might move the reference to NearestNeighbors.jl to the references section with a link?

* `threshold::AbstractFloat`: threshold for missing neighbors
* `on_complete::Function`: a function to run when imputation is complete

# Reference
* Troyanskaya, Olga, et al. "Missing value estimation methods for DNA microarrays." Bioinformatics 17.6 (2001): 520-525.
"""
# TODO : Support Categorical Distance (NearestNeighbors.jl support needed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue for this that you can link to?

struct KNN{M<:NearestNeighbors.MinkowskiMetric} <: Imputor
    # NOTE: the bound on `M` lives in the parameter list on purpose. Written as a trailing
    # `<: Imputor where M <: ...`, the `where` clause attaches to the supertype expression
    # and leaves `M` itself unconstrained.

    # Number of neighbors requested from the tree. The keyword constructor stores the
    # user-facing `k` plus one, because each point is returned as its own neighbor.
    k::Int
    # Scales `k` to decide when too many neighbors were themselves missing and the
    # mean-filled value should be kept instead.
    threshold::AbstractFloat
    # Metric handed to `KDTree`.
    dist::M
    # Called with a `do` block by `impute!`; the yielded context is passed on to `Fill`.
    context::AbstractContext
end

"""
    KNN(; k=1, threshold=0.5, dist=Euclidean(), context=Context())

Keyword constructor for the `KNN` imputor.

# Keywords
* `k`: number of nearest neighbors to use; must be at least 1
* `threshold`: missing-neighbor threshold; must lie strictly between 0 and 1
* `dist`: distance metric stored on the imputor
* `context`: imputation context stored on the imputor

# Throws
* `ArgumentError`: if `k < 1` or `threshold` is not strictly between 0 and 1
"""
function KNN(; k=1, threshold=0.5, dist=Euclidean(), context=Context())
    k >= 1 || throw(ArgumentError(
        "The number of nearest neighbors should be greater than 0, got $k"
    ))

    0 < threshold < 1 || throw(ArgumentError(
        "Missing neighbors threshold should be within 0 to 1 (exclusive), got $threshold"
    ))

    # Store k + 1: the point being imputed is returned as its own nearest neighbor
    # and is discarded later, so one extra neighbor has to be requested.
    return KNN(k + 1, threshold, dist, context)
end

"""
    impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real

Impute the missing entries of `data` in place using inverse-distance-weighted
k-nearest-neighbor values, and return `data`.

Missing entries are first replaced by the `Fill(; value=mean)` imputor so that a
`KDTree` can be built on a complete matrix. Each originally-missing entry is then
replaced by the weighted average of the same variable over its nearest neighbors,
unless one of the fallbacks below applies, in which case the mean-filled value is kept:

* the neighbor weights sum to zero, `Inf` or `NaN` (e.g. a neighbor at distance zero);
* more than `imp.k * imp.threshold` of the neighbors had missing values themselves.
"""
function impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real
    imp.context() do ctx
        # Remember where the missing values were *before* filling anything in. The mask
        # is transposed so that it shares indices with `transposed` below.
        mmask = ismissing.(transpose(data))

        # Seed the missing entries with mean imputation (mutates `data`).
        impute!(data, Fill(; value=mean, context=ctx))

        # `disallowmissing` copies, so `transposed` is a scratch buffer in the
        # D x N layout (one point per column) used for the KDTree.
        transposed = transpose(disallowmissing(data))

        kdtree = KDTree(transposed, imp.dist)
        idxs, dists = NearestNeighbors.knn(kdtree, transposed, imp.k, true)

        # NOTE(review): `imp.k` includes the extra "self" neighbor added by the keyword
        # constructor, while only `imp.k - 1` neighbors are inspected below. Confirm
        # whether the threshold should scale `imp.k - 1` instead.
        fallback_threshold = imp.k * imp.threshold

        for I in CartesianIndices(transposed)
            mmask[I] || continue

            # Results are sorted by distance; the first hit is assumed to be the query
            # point itself (distance zero) and is dropped.
            neighbors = idxs[I[2]][2:end]
            w = 1.0 ./ dists[I[2]][2:end]
            ws = sum(w)

            # Degenerate weights: keep the mean-imputed value.
            (isnan(ws) || isinf(ws) || iszero(ws)) && continue

            # Count neighbors that had missing values of their own. This has to consult
            # the original mask: `transposed` was mean-filled and holds no missings.
            incomplete = count(j -> any(view(mmask, :, j)), neighbors)

            # Too many unreliable neighbors: keep the mean-imputed value.
            incomplete > fallback_threshold && continue

            # Inverse distance weighting
            transposed[I] = sum(w .* transposed[I[1], neighbors]) / ws
        end

        # Write the result back so the `!` method really mutates its argument, and
        # return `data` itself so the caller gets back the type it passed in.
        data .= transpose(transposed)
        return data
    end
end
127 changes: 125 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

using AxisArrays
using Combinatorics
using DataFrames
Expand Down Expand Up @@ -30,7 +29,6 @@ using Impute:
interp,
chain


function add_missings(X, ratio=0.1)
appleparan marked this conversation as resolved.
Show resolved Hide resolved
result = Matrix{Union{Float64, Missing}}(X)

Expand All @@ -41,6 +39,17 @@ function add_missings(X, ratio=0.1)
return result
end

# Return a `Matrix{Union{Float64, Missing}}` copy of `X` in which each of the first
# `floor(Int, size(X, 2) * ratio)` columns has one entry, at a randomly drawn row,
# replaced by `missing`. `X` itself is not modified.
function add_missings_single(X, ratio=0.1)
    out = Matrix{Union{Float64, Missing}}(X)

    ntargets = floor(Int, size(X, 2) * ratio)
    for j in 1:ntargets
        i = rand(1:size(X, 1))
        out[i, j] = missing
    end

    return out
end

@testset "Impute" begin
# Defining our missing datasets
a = allowmissing(1.0:1.0:20.0)
Expand Down Expand Up @@ -531,6 +540,120 @@ end
end
end

@testset "KNN" begin
appleparan marked this conversation as resolved.
Show resolved Hide resolved
@testset "Iris" begin
    # Reference
    # P. Schmitt, et al.
    # A comparison of six methods for missing data imputation
    iris = dataset("datasets", "iris")
    iris2 = filter(row -> row[:Species] == "versicolor" || row[:Species] == "virginica", iris)
    data = Array(iris2[:, [:SepalLength, :SepalWidth, :PetalLength, :PetalWidth]])
    num_tests = 100

    # Same check at each missing-data ratio: kNN imputation must score a lower NRMSD
    # than plain mean imputation, and both imputors must return the input's type.
    @testset "Iris - $ratio" for ratio in (0.15, 0.25, 0.35)
        X = add_missings(data, ratio)

        knn_nrmsd, mean_nrmsd = 0.0, 0.0

        # Running average of the scores over `num_tests` imputations of the same `X`.
        for i = 1:num_tests
            knn_imputed = impute(copy(X), Impute.KNN(; k=2))
            mean_imputed = impute(
                copy(X), Fill(; value=mean, context=Context(; limit=1.0))
            )

            knn_nrmsd = ((i - 1) * knn_nrmsd + nrmsd(data, knn_imputed)) / i
            mean_nrmsd = ((i - 1) * mean_nrmsd + nrmsd(data, mean_imputed)) / i
        end

        @test knn_nrmsd < mean_nrmsd
        # test type stability
        @test typeof(X) == typeof(impute(copy(X), Impute.KNN(; k=2)))
        @test typeof(X) == typeof(
            impute(copy(X), Fill(; value=mean, context=Context(; limit=1.0)))
        )
    end
end

# Test a case where we expect kNN to perform well (e.g., many correlated variables)
@testset "Data match" begin
    # One column per i in 1:1000. For every combination of the four seed values,
    # eight derived features are stacked vertically (the last one carries random noise).
    data = mapreduce(hcat, 1:1000) do i
        seeds = [sin(i), cos(i), tan(i), atan(i)]
        mapreduce(vcat, combinations(seeds)) do args
            total = +(args...)
            abstotal = +(abs.(args)...)
            [
                total,
                *(args...),
                total * 100,
                abstotal,
                (total * 10) ^ 2,
                (abstotal * 10) ^ 2,
                log(abstotal * 100),
                total * 100 + rand(-10:0.1:10),
            ]
        end
    end

    # The imputors get the transposed matrix, with missing values injected.
    truth = data'
    X = add_missings(truth)
    num_tests = 100

    knn_nrmsd = 0.0
    mean_nrmsd = 0.0

    # Running average of both scores over repeated imputations of the same `X`.
    for trial in 1:num_tests
        by_knn = impute(copy(X), Impute.KNN(; k=2))
        by_mean = impute(copy(X), Fill(; value=mean, context=Context(; limit=1.0)))

        knn_nrmsd = ((trial - 1) * knn_nrmsd + nrmsd(truth, by_knn)) / trial
        mean_nrmsd = ((trial - 1) * mean_nrmsd + nrmsd(truth, by_mean)) / trial
    end

    @test knn_nrmsd < mean_nrmsd

    # Both imputors should hand back the same container type they were given.
    knn_result = impute(copy(X), Impute.KNN(; k=2))
    @test typeof(X) == typeof(knn_result)
    mean_result = impute(copy(X), Fill(; value=mean, context=Context(; limit=1.0)))
    @test typeof(X) == typeof(mean_result)
end
end

include("deprecated.jl")
include("testutils.jl")

Expand Down