-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4546b9a
commit 28e8203
Showing
5 changed files
with
260 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" | ||
KNN <: Imputor | ||
Imputation using k-Nearest Neighbor algorithm. | ||
# Keyword Arguments | ||
* `num_nn::Int`: number of nearest neighbors | ||
* `dist::MinkowskiMetric`: distance metric suppports by `NearestNeighbors.jl` (Euclidean, Chebyshev, Minkowski and Cityblock) | ||
* `on_complete::Function`: a function to run when imputation is complete | ||
# Reference | ||
* Troyanskaya, Olga, et al. "Missing value estimation methods for DNA microarrays." Bioinformatics 17.6 (2001): 520-525. | ||
""" | ||
# TODO : Support Categorical Distance (NearestNeighbors.jl support needed) | ||
struct KNN{M} <: Imputor where M <: NearestNeighbors.MinkowskiMetric | ||
num_nn::Int | ||
dist::M | ||
missing_neighbors_threshold::AbstractFloat | ||
context::AbstractContext | ||
end | ||
|
||
function KNN(; num_nn=1, missing_neighbors_threshold=0.5, | ||
dist=Euclidean(), context=Context()) | ||
|
||
if num_nn < 1 | ||
throw(ArgumentError("The number of nearset neighbors should be greater than 0")) | ||
end | ||
|
||
# to exclude missing value itself | ||
KNN(num_nn + 1, dist, missing_neighbors_threshold, context) | ||
end | ||
|
||
function impute!(data::AbstractMatrix{<:Union{T, Missing}}, imp::KNN) where T<:Real | ||
|
||
imp.context() do ctx | ||
# transpose to D x N for KDTree | ||
transposed = float.(collect(transpose(data))) | ||
|
||
# Get our before and after views of our missing and non-missing data | ||
mmask = ismissing.(transposed) | ||
omask = .!mmask | ||
|
||
mdata = transposed[mmask] | ||
odata = transposed[omask] | ||
|
||
# fill missing value as mean first | ||
impute!(transposed, Fill(; value=mean, context=ctx)) | ||
|
||
# disallow missing to compute kdtree | ||
transposed = disallowmissing(transposed) | ||
|
||
kdtree = KDTree(transposed, imp.dist) | ||
idxs, dists = NearestNeighbors.knn(kdtree, transposed, imp.num_nn, true) | ||
|
||
idxes = CartesianIndices(transposed) | ||
|
||
# iterate all value | ||
for I in CartesianIndices(transposed) | ||
if mmask[I] == 1 | ||
w = 1.0 ./ dists[I[2]] | ||
#@show idxs[I[2]][2], dists[I[2]][2], sum(w[2:end]) | ||
|
||
missing_neighbors = ismissing.(transposed[:, idxs[I[2]]][:, 2:end]) | ||
# exclude missing value itself because distance would be zero | ||
if isnan(sum(w[2:end])) || isinf(sum(w[2:end])) || sum(w[2:end]) == 0.0 | ||
# if distance is zero, keep mean imputation | ||
transposed[I] = transposed[I] | ||
elseif count(!=(0), mapslices(sum, missing_neighbors, dims=[1])) > | ||
imp.num_nn * imp.missing_neighbors_threshold | ||
# If too many neighbors are also missing, fallback to mean imputation | ||
# get column and check how many neighbors are also missing | ||
transposed[I] = transposed[I] | ||
else | ||
# Inverse distance weighting | ||
# idxs = num_points x num_nearestpoints | ||
wsum = w .* transposed[I[1], idxs[I[2]]] | ||
transposed[I] = sum(wsum[2:end]) / sum(w[2:end]) | ||
end | ||
end | ||
end | ||
|
||
data = transposed' | ||
|
||
data | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters