Skip to content

Commit

Permalink
Simple implementation of KNN imputation
Browse files Browse the repository at this point in the history
* inspired by SVD imputation (invenia#16)
  • Loading branch information
appleparan committed Feb 18, 2020
1 parent a7c6cec commit 2d09bb9
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 5 deletions.
152 changes: 152 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# This file is machine-generated - editing it directly is not advised

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[Combinatorics]]
deps = ["Polynomials"]
git-tree-sha1 = "140cc833258df8e5aafabdb068875ac0c256be23"
uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
version = "1.0.0"

[[DataAPI]]
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.1.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.9"

[[DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
version = "1.0.0"

[[Distances]]
deps = ["LinearAlgebra", "Statistics"]
git-tree-sha1 = "23717536c81b63e250f682b0e0933769eecd1411"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.8.2"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[IterTools]]
git-tree-sha1 = "05110a2ab1fc5f932622ffea2a003221f4782c18"
uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
version = "1.3.0"

[[IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.3"

[[NearestNeighbors]]
deps = ["Distances", "StaticArrays"]
git-tree-sha1 = "8bc6180f328f3c0ea2663935db880d34c57d6eae"
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
version = "0.4.4"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[Polynomials]]
deps = ["LinearAlgebra", "RecipesBase"]
git-tree-sha1 = "ae71c2329790af97b7682b11241b3609e4d48626"
uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
version = "0.6.0"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RecipesBase]]
git-tree-sha1 = "7bdce29bc9b2f5660a6e5e64d64d91ec941f6aa2"
uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
version = "0.7.0"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.12.1"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "be5c7d45daa449d12868f4466dbf5882242cf2d9"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.32.1"

[[TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e"
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.0"

[[Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"]
git-tree-sha1 = "aaed7b3b00248ff6a794375ad6adf30f30ca5591"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "0.2.11"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ authors = ["Invenia Technical Computing"]
version = "0.4.0"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -20,10 +22,11 @@ julia = "1"

[extras]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AxisArrays", "DataFrames", "Dates", "RDatasets", "Test"]
test = ["AxisArrays", "Combinatorics", "DataFrames", "Dates", "RDatasets", "Test"]
3 changes: 3 additions & 0 deletions src/Impute.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module Impute

using Distances
using IterTools
using NearestNeighbors
using Random
using Statistics
using StatsBase
Expand Down Expand Up @@ -63,6 +65,7 @@ const global imputation_methods = (
locf = LOCF,
nocb = NOCB,
srs = SRS,
knn = KNN,
)

include("deprecated.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/imputors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,6 @@ function impute!(table, imp::Imputor)
end


for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl")
for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl", "knn.jl")
include(joinpath("imputors", file))
end
74 changes: 74 additions & 0 deletions src/imputors/knn.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
    KNN{M<:NearestNeighbors.MinkowskiMetric} <: Imputor

Imputation using the k-Nearest Neighbor algorithm.

# Fields
* `num_nn::Int`: number of nearest neighbours requested from the KD-tree
* `dist::M`: metric used to build the `KDTree`
* `context::AbstractContext`: context the imputation is run in

# Reference
* Troyanskaya, Olga, et al. "Missing value estimation methods for DNA microarrays." Bioinformatics 17.6 (2001): 520-525.
"""
struct KNN{M<:NearestNeighbors.MinkowskiMetric} <: Imputor
    # The bound on `M` has to live in the parameter list: written as a trailing
    # `where` clause it attaches to the supertype expression and leaves `M` unconstrained.
    num_nn::Int
    # TODO : Support Categorical Distance (NearestNeighbors.jl support needed)
    dist::M
    context::AbstractContext
end

"""
    KNN(; num_nn=1, dist=Euclidean(), context=Context())

Keyword constructor for [`KNN`](@ref); forwards the keywords, in order, to the
positional constructor.
"""
KNN(; num_nn=1, dist=Euclidean(), context=Context()) = KNN(num_nn, dist, context)

"""
    impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real

Fill the missing entries of `data` in place by inverse-distance-weighted k-nearest
neighbour imputation and return `data`.

`data` is an `N x D` matrix: every row is an observation, every column a variable.

The procedure is:
1. remember where the missing entries are and mean-fill `data` column by column, so that
   distances between complete observations can be computed;
2. build a `KDTree` over the observations using `imp.dist`;
3. for every observation that had missing entries, look up its `imp.num_nn` nearest
   *other* observations;
4. replace each originally missing entry by the weighted average of that variable over
   the neighbours that actually observed it. If none of the neighbours observed the
   variable, the mean-filled value is kept.

When fewer than `imp.num_nn` other observations exist, all of them are used.

# Throws
* `ArgumentError`: if `imp.num_nn < 1`
"""
function impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real
    imp.num_nn >= 1 ||
        throw(ArgumentError("num_nn must be at least 1, got $(imp.num_nn)"))

    imp.context() do ctx
        # Record the missing pattern before it is overwritten by the mean fill
        mmask = ismissing.(data)
        # Observations (rows) that need at least one value imputed
        mrows = findall(vec(any(mmask; dims=2)))

        # Fill in the original data as mean value
        # TODO : pass Fill in constructor
        impute!(data, Fill(; value=mean, context=ctx))

        nobs = size(data, 1)
        # Nothing to refine without missing entries or without any other observation
        (isempty(mrows) || nobs < 2) && return data

        # transpose to D x N for KDTree (one observation per column)
        dataT = float.(collect(transpose(data)))
        kdtree = KDTree(dataT, imp.dist)

        # Every query point is itself stored in the tree, so ask for one extra
        # neighbour and discard the query point afterwards. Results are sorted by
        # increasing distance.
        k = min(imp.num_nn + 1, nobs)
        idxs, dists = NearestNeighbors.knn(kdtree, dataT[:, mrows], k, true)

        # TODO : going to parallel?
        for (q, row) in enumerate(mrows)
            keep = findall(j -> j != row, idxs[q])
            # With many exact duplicates the query point may be absent from its own
            # neighbour list; trim so that at most `num_nn` neighbours are used.
            length(keep) > imp.num_nn && resize!(keep, imp.num_nn)
            nbrs = idxs[q][keep]
            nbrdists = dists[q][keep]

            for col in findall(view(mmask, row, :))
                # Neighbours whose value for this variable is itself only a mean fill
                # carry no information, so they are left out of the average.
                observed = findall(j -> !mmask[j, col], nbrs)
                isempty(observed) && continue

                w = _knn_weights(nbrdists[observed])
                data[row, col] = sum(w .* dataT[col, nbrs[observed]])
            end
        end

        return data
    end
end

# Normalised inverse-distance weights; zero-distance neighbours share all the weight.
function _knn_weights(dists::AbstractVector{<:Real})
    exact = iszero.(dists)
    any(exact) && return exact ./ count(exact)

    w = inv.(dists)
    return w ./ sum(w)
end
47 changes: 44 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using Impute
using Tables
using Test
using AxisArrays
using Combinatorics
using DataFrames
using Dates
using Distances
using Random
using RDatasets
using Statistics
using StatsBase
using Random
using Tables
using Test

using Impute:
Drop,
Expand All @@ -24,6 +26,15 @@ using Impute:
interp,
chain

"""
    add_missings(X, ratio=0.1; rng=Random.GLOBAL_RNG)

Return a new `Matrix{Union{Float64, Missing}}` holding the values of `X` in which
`floor(Int, length(X) * ratio)` randomly chosen entries (clamped to `0:length(X)`) are
replaced by `missing`.

Positions are drawn without replacement, so the number of missing entries is exact.
Pass a seeded `rng` to obtain reproducible test data.
"""
function add_missings(X, ratio=0.1; rng::AbstractRNG=Random.GLOBAL_RNG)
    result = Matrix{Union{Float64, Missing}}(X)

    nmissing = clamp(floor(Int, length(X) * ratio), 0, length(X))
    # A random permutation yields distinct positions; sampling with `rand` could hit
    # the same entry twice and produce fewer missing values than requested.
    for i in view(randperm(rng, length(result)), 1:nmissing)
        result[i] = missing
    end

    return result
end

@testset "Impute" begin
# Defining our missing datasets
Expand Down Expand Up @@ -515,5 +526,35 @@ using Impute:
end
end

@testset "KNN" begin
# Test a case with only a few variables
@testset "Data - few variables" begin
data = Matrix(dataset("Ecdat", "Electricity"))
X = add_missings(data)

knn_imputed = impute(copy(X), Impute.KNN(; num_nn=3, dist=Euclidean(), context=Context(; limit = 1.0)))
mean_imputed = impute(copy(X), Fill(; context=Context(; limit=1.0)))

# The assertion only requires the KNN error to exceed 90% of the mean-imputation
# error, i.e. KNN is not expected to clearly beat mean imputation on this data.
# NOTE(review): this threshold appears to be carried over from the SVD imputation
# tests; confirm it is the intended expectation for KNN.
@test nrmsd(knn_imputed, data) > nrmsd(mean_imputed, data) * 0.9
end

@testset "Data - random variables" begin
M = rand(100, 200)
data = M * M'
X = add_missings(data)

knn_imputed = impute(copy(X), Impute.KNN(; num_nn=3, dist=Euclidean(), context=Context(; limit = 1.0)))
mean_imputed = impute(copy(X), Fill(; context=Context(; limit=1.0)))

# `data` is built from uniformly random values, so KNN is not expected to clearly
# beat mean imputation; the assertion only requires the KNN error to exceed 90% of
# the mean-imputation error.
# NOTE(review): this threshold appears to be carried over from the SVD imputation
# tests; confirm it is the intended expectation for KNN.
@test nrmsd(knn_imputed, data) > nrmsd(mean_imputed, data) * 0.9
end
end

include("deprecated.jl")
end

0 comments on commit 2d09bb9

Please sign in to comment.