From 613226189ccb4e57fa06c62a64886fb4cb10ff2f Mon Sep 17 00:00:00 2001 From: Liam Jongsu Kim Date: Wed, 19 Feb 2020 04:45:29 +0900 Subject: [PATCH] Simple implementation of KNN imputation * inspired by SVD imputation (#16) --- Manifest.toml | 152 ++++++++++++++++++++++++++++++++++++++++++++ Project.toml | 5 +- src/Impute.jl | 3 + src/imputors.jl | 2 +- src/imputors/knn.jl | 79 +++++++++++++++++++++++ test/runtests.jl | 47 +++++++++++++- 6 files changed, 283 insertions(+), 5 deletions(-) create mode 100644 Manifest.toml create mode 100644 src/imputors/knn.jl diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..e5c2b36 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,152 @@ +# This file is machine-generated - editing it directly is not advised + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[Combinatorics]] +deps = ["Polynomials"] +git-tree-sha1 = "140cc833258df8e5aafabdb068875ac0c256be23" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.0" + +[[DataAPI]] +git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.1.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.17.9" + +[[DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[Distances]] +deps = ["LinearAlgebra", "Statistics"] +git-tree-sha1 = "23717536c81b63e250f682b0e0933769eecd1411" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.8.2" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[IterTools]] +git-tree-sha1 = "05110a2ab1fc5f932622ffea2a003221f4782c18" +uuid = 
"c8e1da08-722c-5040-9ed9-7db0dc04731e" +version = "1.3.0" + +[[IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.3" + +[[NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = "8bc6180f328f3c0ea2663935db880d34c57d6eae" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.4" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.1.0" + +[[Polynomials]] +deps = ["LinearAlgebra", "RecipesBase"] +git-tree-sha1 = "ae71c2329790af97b7682b11241b3609e4d48626" +uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" +version = "0.6.0" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[RecipesBase]] +git-tree-sha1 = "7bdce29bc9b2f5660a6e5e64d64d91ec941f6aa2" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "0.7.0" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + 
+[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.12.1" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "be5c7d45daa449d12868f4466dbf5882242cf2d9" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.32.1" + +[[TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.0" + +[[Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] +git-tree-sha1 = "aaed7b3b00248ff6a794375ad6adf30f30ca5591" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "0.2.11" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/Project.toml b/Project.toml index 8eb142d..c8e8f3d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,9 @@ authors = ["Invenia Technical Computing"] version = "0.4.0" [deps] +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -20,10 +22,11 @@ julia = "1" [extras] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" RDatasets = 
"ce6b1742-4840-55fa-b093-852dadbb1d8b"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["AxisArrays", "DataFrames", "Dates", "RDatasets", "Test"]
+test = ["AxisArrays", "Combinatorics", "DataFrames", "Dates", "RDatasets", "Test"]
diff --git a/src/Impute.jl b/src/Impute.jl
index a113f4b..b594db0 100644
--- a/src/Impute.jl
+++ b/src/Impute.jl
@@ -1,6 +1,8 @@
 module Impute
 
+using Distances
 using IterTools
+using NearestNeighbors
 using Random
 using Statistics
 using StatsBase
@@ -63,6 +65,7 @@ const global imputation_methods = (
     locf = LOCF,
     nocb = NOCB,
     srs = SRS,
+    knn = KNN,
 )
 
 include("deprecated.jl")
diff --git a/src/imputors.jl b/src/imputors.jl
index 92e64f7..48e298f 100644
--- a/src/imputors.jl
+++ b/src/imputors.jl
@@ -167,6 +167,6 @@ function impute!(table, imp::Imputor)
 end
 
 
-for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl")
+for file in ("drop.jl", "locf.jl", "nocb.jl", "interp.jl", "fill.jl", "chain.jl", "srs.jl", "knn.jl")
     include(joinpath("imputors", file))
 end
diff --git a/src/imputors/knn.jl b/src/imputors/knn.jl
new file mode 100644
index 0000000..22317b1
--- /dev/null
+++ b/src/imputors/knn.jl
@@ -0,0 +1,88 @@
+# TODO: Support categorical distances (needs support in NearestNeighbors.jl)
+"""
+    KNN <: Imputor
+
+Imputation using the k-nearest-neighbor algorithm. Each row of the input matrix is
+treated as one observation and each column as one variable.
+
+# Keyword Arguments
+* `num_nn::Int`: number of nearest neighbors to use (default `1`)
+* `dist::MinkowskiMetric`: distance metric supported by `NearestNeighbors.jl`
+  (`Euclidean`, `Chebyshev`, `Minkowski` and `Cityblock`; default `Euclidean()`)
+* `context::AbstractContext`: imputation context (default `Context()`)
+
+# Reference
+* Troyanskaya, Olga, et al. "Missing value estimation methods for DNA microarrays."
+  Bioinformatics 17.6 (2001): 520-525.
+"""
+struct KNN{M<:NearestNeighbors.MinkowskiMetric} <: Imputor
+    num_nn::Int
+    dist::M
+    context::AbstractContext
+end
+
+function KNN(; num_nn=1, dist=Euclidean(), context=Context())
+    num_nn >= 1 || throw(ArgumentError("num_nn must be at least 1, got $num_nn"))
+    return KNN(num_nn, dist, context)
+end
+
+# Inverse-distance weighted mean of `vals`. Neighbors at distance zero are exact
+# matches and are averaged on their own, which also avoids dividing by zero.
+function _knn_weighted_mean(vals, ds)
+    exact = iszero.(ds)
+    any(exact) && return mean(vals[exact])
+    w = inv.(ds)
+    return sum(w .* vals) / sum(w)
+end
+
+"""
+    impute!(data::AbstractMatrix, imp::KNN)
+
+Fill the missing entries of the `N x D` matrix `data` (`N` observations in rows, `D`
+variables in columns) in place and return it.
+
+Missing entries are first seeded through `Fill(; value=mean)` so that distances
+between observations are defined. Each missing entry is then replaced by the
+inverse-distance weighted mean of that variable over the `imp.num_nn` nearest
+observations, ignoring neighbors for which the variable was itself missing. An entry
+keeps its seed value when no usable neighbor remains.
+"""
+function impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real
+    imp.context() do ctx
+        # Record where the gaps are before the seeding step overwrites them.
+        mmask = ismissing.(data)
+        any(mmask) || return data
+
+        # TODO: accept the seeding imputor through the constructor
+        impute!(data, Fill(; value=mean, context=ctx))
+
+        # KDTree expects one observation per column, i.e. a D x N matrix.
+        points = float.(collect(transpose(data)))
+
+        # Every point is among its own nearest neighbors (distance zero), so request
+        # one extra neighbor and skip the query point below.
+        k = min(imp.num_nn + 1, size(points, 2))
+        kdtree = KDTree(points, imp.dist)
+        idxs, dists = NearestNeighbors.knn(kdtree, points, k, true)
+
+        for I in findall(mmask)
+            row, col = Tuple(I)
+            vals = eltype(points)[]
+            ds = eltype(points)[]
+
+            for (j, d) in zip(idxs[row], dists[row])
+                # Skip the observation itself and neighbors lacking this variable.
+                (j == row || mmask[j, col]) && continue
+                push!(vals, points[col, j])
+                push!(ds, d)
+                length(vals) == imp.num_nn && break
+            end
+
+            # Values come from the `points` snapshot, so the result does not depend
+            # on the order in which entries are imputed.
+            isempty(vals) || (data[row, col] = _knn_weighted_mean(vals, ds))
+        end
+
+        return data
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 8ce1b4b..acb10b5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,13 +1,15 @@
 using Impute
-using Tables
-using Test
 using AxisArrays
+using Combinatorics
 using DataFrames
 using Dates
+using Distances
+using Random
 using RDatasets
 using Statistics
 using StatsBase
-using Random
+using Tables
+using Test
 
 using Impute:
     Drop,
@@ -24,6 +26,17 @@ using Impute:
     interp,
     chain
 
+function add_missings(X, ratio=0.1; rng=Random.GLOBAL_RNG)
+    result = Matrix{Union{Float64, Missing}}(X)
+
+    # Indices are drawn with replacement, so slightly fewer than `ratio` of the
+    # entries may end up missing.
+    for _ in 1:floor(Int, length(X) * ratio)
+        result[rand(rng, 1:length(X))] = missing
+    end
+
+    return result
+end
 
 @testset "Impute" begin
     # Defining our missing datasets
@@ -515,5 +528,47 @@ using Impute:
         end
     end
 
+    @testset "KNN" begin
+        @testset "Tiny matrix" begin
+            X = Union{Float64, Missing}[1.0 1.0; 1.0 missing; 5.0 5.0]
+
+            # Row 1 is the observation closest to row 2, so its value is copied.
+            nearest = impute(copy(X), Impute.KNN(; num_nn=1, context=Context(; limit=1.0)))
+            @test nearest[2, 2] == 1.0
+            @test nearest[[1, 3], :] == X[[1, 3], :]
+
+            # With two neighbors the estimate is pulled towards the closer row.
+            weighted = impute(copy(X), Impute.KNN(; num_nn=2, context=Context(; limit=1.0)))
+            @test 1.0 <= weighted[2, 2] < 3.0
+        end
+
+        # Test a case with only a few variables
+        @testset "Data - few variables" begin
+            data = Matrix(dataset("Ecdat", "Electricity"))
+            X = add_missings(data; rng=MersenneTwister(137))
+            mask = ismissing.(X)
+
+            knn_imputed = impute(copy(X), Impute.KNN(; num_nn=3, dist=Euclidean(), context=Context(; limit=1.0)))
+
+            # Every gap is filled and observed entries are left untouched.
+            @test !any(ismissing, knn_imputed)
+            @test knn_imputed[.!mask] == data[.!mask]
+        end
+
+        @testset "Data - random variables" begin
+            rng = MersenneTwister(137)
+            M = rand(rng, 100, 200)
+            data = M * M'
+            X = add_missings(data; rng=rng)
+            mask = ismissing.(X)
+
+            knn_imputed = impute(copy(X), Impute.KNN(; num_nn=3, dist=Euclidean(), context=Context(; limit=1.0)))
+
+            # Every gap is filled and observed entries are left untouched.
+            @test !any(ismissing, knn_imputed)
+            @test knn_imputed[.!mask] == data[.!mask]
+        end
+    end
+
     include("deprecated.jl")
 end