Skip to content

Commit

Permalink
Fixed bug in kNN graph loader where the wrong object was called when …
Browse files Browse the repository at this point in the history
…more neighbors were queried.
  • Loading branch information
ejohnson643 committed Oct 31, 2021
1 parent a0c98a1 commit 6499b68
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 160 deletions.
6 changes: 4 additions & 2 deletions EMBEDR/embedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import scipy.stats as st
from sklearn.utils import check_array, check_random_state
from umap import UMAP

import warnings

class EMBEDR(object):

Expand Down Expand Up @@ -753,6 +753,7 @@ def load_kNN_graph(self,
## If a path has been found to a matching kNN graph load it!
with open(kNN_path, 'rb') as f:
kNNObj = pkl.load(f)
kNNObj.verbose = self.verbose

## If the kNN is an ANNOY object, try to load the ANNOY index using the
## current project directory.
Expand All @@ -774,7 +775,7 @@ def load_kNN_graph(self,
print_str += f"... not enough neighbors, querying for more!"
print(print_str)

idx, dst = out.query(X, self._max_nn + 1)
idx, dst = kNNObj.query(X, self._max_nn + 1)
kNNObj.kNN_dst = dst[:, 1:]
kNNObj.kNN_idx = idx[:, 1:]

Expand Down Expand Up @@ -1004,6 +1005,7 @@ def load_affinity_matrix(self,
## If a path has been found to a matching kNN graph load it!
with open(aff_path, 'rb') as f:
affObj = pkl.load(f)
affObj.verbose = self.verbose

if self.verbose >= 3:
print(f"Affinity matrix successfully loaded!")
Expand Down
78 changes: 45 additions & 33 deletions EMBEDR/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,11 @@ def __init__(self,

def fit(self, X, k_NN):

timer_str = f"Finding {k_NN} nearest neighbors using an exact search"
timer_str += f" and the {self.metric} metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()
if self.verbose:
timer_str = f"Finding {k_NN} nearest neighbors using an exact"
timer_str += f" search and the {self.metric} metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()

## Get the data shape
self.n_samples, self.n_features = X.shape[0], X.shape[1]
Expand All @@ -241,7 +242,8 @@ def fit(self, X, k_NN):
## Return the indices and distances of the k_NN nearest neighbors.
distances, NN_idx = self.index.kneighbors(n_neighbors=k_NN)

timer.__exit__()
if self.verbose:
timer.__exit__()

self.kNN_idx = NN_idx[:, :]
self.kNN_dst = distances[:, :]
Expand All @@ -261,18 +263,20 @@ def query(self, query, k_NN):
NFE.args[0] = err_str + "\n\n" + NFE.args[0]
raise NFE

timer_str = f"Finding {k_NN} nearest neighbors in an existing kNN"
timer_str += f" graph using an exact search and the {self.metric}"
timer_str += f" metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()
if self.verbose:
timer_str = f"Finding {k_NN} nearest neighbors in an existing kNN"
timer_str += f" graph using an exact search and the {self.metric}"
timer_str += f" metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()

## Find the indices and distances to the nearest neighbors of the
## queried points
distances, NN_idx = self.index.kneighbors(query, n_neighbors=k_NN)

## Stop the watch
timer.__exit__()
if self.verbose:
timer.__exit__()

## Return the indices of the nearest neighbors to the queried points
## *in the original graph* and the distances to those points.
Expand Down Expand Up @@ -367,10 +371,11 @@ def __init__(self,

def fit(self, X, k_NN):

timer_str = f"Finding {k_NN} nearest neighbors using an approximate"
timer_str += f" search and the {self.metric} metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()
if self.verbose:
timer_str = f"Finding {k_NN} nearest neighbors using an"
timer_str += f" approximate search and the {self.metric} metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()

X = check_array(X, accept_sparse=True, ensure_2d=True)

Expand Down Expand Up @@ -416,7 +421,8 @@ def getnns(ii):
Parallel(n_jobs=self.n_jobs, require="sharedmem")(
delayed(getnns)(ii) for ii in range(self.n_samples))

timer.__exit__()
if self.verbose:
timer.__exit__()

self.kNN_idx = NN_idx[:, :]
self.kNN_dst = distances[:, :]
Expand All @@ -438,11 +444,12 @@ def query(self, query, k_NN):
err_str += f" constructed! (Run kNNIndex.fit(X, k_NN))"
raise ValueError(err_str)

timer_str = f"Finding {k_NN} nearest neighbors to query points in"
timer_str += f" existing kNN graph using an approximate search and the"
timer_str += f" '{self.metric}'' metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()
if self.verbose:
timer_str = f"Finding {k_NN} nearest neighbors to query points in"
timer_str += f" existing kNN graph using an approximate search and"
timer_str += f" the '{self.metric}'' metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()

## Check query shape, if 1D array, reshape.
if query.ndim == 1:
Expand Down Expand Up @@ -480,7 +487,8 @@ def getnns(ii):
delayed(getnns)(ii) for ii in range(n_query)
)

timer.__exit__()
if self.verbose:
timer.__exit__()

return NN_idx, distances

Expand Down Expand Up @@ -673,10 +681,11 @@ def check_metric(self, metric):

def fit(self, X, k_NN):

timer_str = f"Finding {k_NN} approximate nearest neighbors using"
timer_str += f" NNDescent and the '{self.metric}' metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()
if self.verbose:
timer_str = f"Finding {k_NN} approximate nearest neighbors using"
timer_str += f" NNDescent and the '{self.metric}' metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()

## Get the data shape
self.n_samples, self.n_features = X.shape[0], X.shape[1]
Expand Down Expand Up @@ -739,7 +748,8 @@ def fit(self, X, k_NN):

raise ValueError(err_str)

timer.__exit__()
if self.verbose:
timer.__exit__()

# return NN_idx[:, 1:], distances[:, 1:]
self.kNN_idx = NN_idx[:, 1:]
Expand All @@ -756,15 +766,17 @@ def query(self, query, k_NN):
err_str += f" constructed! (Run kNNIndex.fit(X, k_NN))"
raise ValueError(err_str)

timer_str = f"Finding {k_NN} approximate nearest neighbors to query "
timer_str += f" points in the existing NN graph using `pynndescent`"
timer_str += f" and the '{self.metric}' metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()
if self.verbose:
timer_str = f"Finding {k_NN} approximate nearest neighbors to"
timer_str += f" query points in the existing NN graph using"
timer_str += f" `pynndescent` and the '{self.metric}' metric..."
timer = utl.Timer(timer_str, verbose=self.verbose)
timer.__enter__()

NN_idx, distances = self.index.query(query, k=k_NN)

timer.__exit__()
if self.verbose:
timer.__exit__()

return NN_idx, distances

Expand Down
38 changes: 27 additions & 11 deletions EMBEDR/plots/EMBEDR_Figure_01v1_DimRed_Zoology.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,22 @@ def make_figure(X, cluster_labels, clusters_2_label=None, label_colors=None,

if label_colors is None:
cblind_cmap = sns.color_palette('colorblind')
l2cl = {cl: (ii + 3) % 10 for ii, cl in enumerate(clust_2_label)}
label_colors = [cblind_cmap[l2cl[ll]] if (ll in clust_2_label)
else 'lightgrey' for ll in labels]
l2cl = {cl: (ii + 3) % 10
for ii, cl in enumerate(clusters_2_label)}
label_colors = [cblind_cmap[l2cl[ll]] if (ll in clusters_2_label)
else 'lightgrey' for ll in cluster_labels.squeeze()]
label_colors = np.asarray(label_colors)

if label_sizes is None:
label_sizes = [3 if (ll in clust_2_label) else 1 for ll in labels]
label_sizes = [3 if (ll in clusters_2_label) else 1
for ll in cluster_labels]
label_sizes = np.asarray(label_sizes)

if DRAs is None:
## Set parameters at which to plot data
DRAs = [('tSNE', 7),
DRAs = [('tSNE', 9),
('UMAP', 15),
('tSNE', 250),
('tSNE', 350),
('UMAP', 400)]

if project_name is None:
Expand Down Expand Up @@ -89,11 +93,23 @@ def make_figure(X, cluster_labels, clusters_2_label=None, label_colors=None,
**EMBEDR_params)
Y, _ = embObj.get_tSNE_embedding(X)

if alg.lower() in ['umap']:
embObj = EMBEDR(X=X,
n_neighbors=param,
DRA='umap',
n_data_embed=1,
n_jobs=-1,
project_name=project_name,
project_dir=project_dir,
**EMBEDR_params)
Y, _ = embObj.get_UMAP_embedding(X)

rowNo = int(algNo / n_cols)
colNo = int(algNo % n_cols)
ax = main_axes[rowNo][colNo]

add_plot_color_by_cluster(Y, cluster_labels)
add_plot_color_by_cluster(Y[0], cluster_labels, ax, label_colors,
label_sizes, clusters_2_label)
return


Expand Down Expand Up @@ -128,15 +144,15 @@ def set_main_grid(fig_wid=7.2, fig_hgt=5.76, n_rows=2, n_cols=2,


def add_plot_color_by_cluster(Y, cluster_labels, ax, label_colors, label_sizes,
clusters_2_label):
clusters_2_label, scatter_alpha=0.2):

ax.scatter(*Y.T,
c=label_colors,
s=label_sizes,
alpha=scatter_alpha)

for cNo, cluster in enumerate(clusters_2_label):
good_idx = (cluster_labels == clusters)
good_idx = (cluster_labels == cluster).squeeze()

cluster_median = np.median(Y[good_idx], axis=0)

Expand Down Expand Up @@ -329,8 +345,8 @@ def add_plot_color_by_cluster(Y, cluster_labels, ax, label_colors, label_sizes,
cell_ont_ids = sorted(cell_ont_ids,
key=lambda cO: -cell_ont_counts[cO])

cell_ont_labels = [f"{cO} (N = {cell_ont_counts[cO]})"
for cO in cell_ont_ids]
cell_ont_labels = np.asarray([f"{cO} (N = {cell_ont_counts[cO]})"
for cO in cell_ont_ids])

cell_ont_cmap = sns.color_palette('husl', len(cell_ont_ids))

Expand Down
Loading

0 comments on commit 6499b68

Please sign in to comment.