Skip to content

Commit

Permalink
Added a wrapper in the EMBEDR_sweep class to enable use of the EMBEDR class plotting routines on embeddings within the sweep.
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohnson643 committed Nov 6, 2021
1 parent 31ebd8a commit 170875c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 2 deletions.
94 changes: 94 additions & 0 deletions EMBEDR/embedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,19 @@ def plot(self,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

elif plot_type.lower() in ['perplexity']:

from EMBEDR.plots.embedr_scatterplots import Scatterplot

plotObj = Scatterplot(plot_Y,
self.perplexity,
fig=fig,
axis=axis,
cbar_ax=cbar_ax,
show_cbar=show_cbar,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

elif (metadata is not None) and (plot_type in metadata):

if is_categorical:
Expand Down Expand Up @@ -2641,5 +2654,86 @@ def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None,
xticklabels=xticklabels,
**kwargs)

def plot_embedding(self,
                   param_2_plot='optimal',
                   plot_type='pvalue',
                   metadata=None,
                   is_categorical=False,
                   plot_data=True,
                   embed_2_show=0,
                   fig=None,
                   axis=None,
                   cbar_ax=None,
                   show_cbar=True,
                   cite_EMBEDR=True,
                   **plot_kwds):
    """Plot one embedding from the sweep using the `EMBEDR.plot` routine.

    Parameters
    ----------
    param_2_plot: 'optimal' or hyperparameter value (default: 'optimal')
        If 'optimal', plotting is delegated to `self.opt_obj`.  Otherwise
        the value is used as the key into the sweep's stored embeddings,
        EES values, and p-values.  It is passed to a temporary `EMBEDR`
        object as `perplexity` when `self.sweep_type == 'perplexity'`
        and as `n_neighbors` otherwise.

    plot_type, metadata, is_categorical, plot_data, embed_2_show, fig, \
    axis, cbar_ax, show_cbar, cite_EMBEDR, **plot_kwds
        Forwarded unchanged to the `plot` method of the object that ends
        up doing the plotting.

    Returns
    -------
    Whatever the underlying `plot` call returns.

    Raises
    ------
    ValueError
        If `param_2_plot` is 'optimal' but no optimal embedding object
        is available on this sweep.
    KeyError
        If `param_2_plot` is not a key of the stored sweep results.
    """

    ## Everything that is handed through to `plot`, collected once so the
    ## two call sites below cannot drift apart.
    shared_kwds = dict(plot_type=plot_type,
                       metadata=metadata,
                       is_categorical=is_categorical,
                       plot_data=plot_data,
                       embed_2_show=embed_2_show,
                       fig=fig,
                       axis=axis,
                       cbar_ax=cbar_ax,
                       show_cbar=show_cbar,
                       cite_EMBEDR=cite_EMBEDR)

    if param_2_plot == 'optimal':
        ## Check for the optimal object up front rather than catching
        ## AttributeError around the whole plot call: that would relabel
        ## any AttributeError raised *inside* the plotting code as
        ## "the optimal embedding does not exist".
        opt_obj = getattr(self, 'opt_obj', None)
        if opt_obj is None:
            err_str = "Could not plot 'optimal' embedding as it does not"
            err_str += " exist. (Run EMBEDR_sweep.fit_samplewise_optimal"
            err_str += " or set `param_2_plot` to a value that was"
            err_str += " swept by the hyperparameter sweep in `fit`.)"
            raise ValueError(err_str)

        return opt_obj.plot(**shared_kwds, **plot_kwds)

    ## Fetch the stored results first, so that an un-swept value fails
    ## with a clear message before a temporary EMBEDR object is built.
    try:
        data_Y = self.embeddings[param_2_plot]
        data_EES = self.data_EES[param_2_plot]
        null_EES = self.null_EES[param_2_plot]
        pValues = self.pValues[param_2_plot]
    except KeyError as err:
        err_str = f"Could not plot embedding for `param_2_plot`="
        err_str += f"{param_2_plot!r}: no results are stored for that"
        err_str += f" value in this sweep."
        raise KeyError(err_str) from err

    ## Only one of perplexity / n_neighbors is set, based on sweep type.
    perp = knn = None
    if self.sweep_type == 'perplexity':
        perp = param_2_plot
    else:
        knn = param_2_plot

    ## Temporary object that mirrors this sweep's settings; it is only
    ## used to reach the EMBEDR plotting routines (verbose is forced off).
    tmpEmb = EMBEDR(perplexity=perp,
                    n_neighbors=knn,
                    kNN_metric=self.kNN_metric,
                    kNN_alg=self.kNN_alg,
                    kNN_params=self.kNN_params,
                    aff_type=self.aff_type,
                    aff_params=self.aff_params,
                    n_components=self.n_components,
                    DRA=self.DRA,
                    DRA_params=self.DRA_params,
                    EES_type=self.EES_type,
                    EES_params=self.EES_params,
                    pVal_type=self.pVal_type,
                    n_data_embed=self.n_data_embed,
                    n_null_embed=self.n_null_embed,
                    n_jobs=self.n_jobs,
                    random_state=self.rs,
                    verbose=0,
                    do_cache=self.do_cache,
                    project_name=self.project_name,
                    project_dir=self.project_dir)

    ## Inject the already-computed sweep results instead of re-fitting.
    tmpEmb.data_Y = data_Y
    tmpEmb.data_EES = data_EES
    tmpEmb.null_EES = null_EES
    tmpEmb.pValues = pValues

    return tmpEmb.plot(**shared_kwds, **plot_kwds)





8 changes: 7 additions & 1 deletion EMBEDR/plots/embedr_scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def __init__(self, Y, label, metadata, **kwargs):
if self.cmap is None:
self.cmap = 'husl'

print(self.category_kwds)

out = putl.process_categorical_label(metadata,
label,
cmap=self.cmap,
Expand Down Expand Up @@ -291,21 +293,25 @@ def _plot(self, **kwargs):
color = self.label_cmap[self._l2i_map[label]]

if self.show_legend:
title = self.long_labels[lNo]
title = self.long_labels[self._l2i_map[label]]
if "Multipotent" in title:
title = " ".join(title.split(" Multipotent "))
else:
title = None

zorder = None

else:
color = self.bkgd_label_kwds['color']
title = None
zorder = -1

self.axis.scatter(*self.Y[good_idx, :2].T,
color=color,
s=self.sct_s[good_idx],
alpha=self.sct_a,
label=title,
zorder=zorder,
**self.sct_kwds)

if self.show_legend:
Expand Down
3 changes: 3 additions & 0 deletions EMBEDR/plots/sweep_lineplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def sweep_lineplot_byCat(hyperparam_array,

n_labels = len(label_counts)

if labels_2_show == 'all':
labels_2_show = label_counts.index.values

if n_rows is None:
n_rows = int(np.ceil(n_labels / n_cols))

Expand Down
4 changes: 3 additions & 1 deletion EMBEDR/plotting_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,9 @@ def process_categorical_label(metadata, label, cmap='colorblind',
unique_labels = label_counts.index.values

if alphabetical_sort:
unique_labels = np.sort(unique_labels)
str_labels = unique_labels.astype(str)
unique_labels = unique_labels[np.argsort(str_labels)]
label_counts = label_counts.reindex(unique_labels)

## Make some nice long labels.
long_labels = np.asarray([f"{ll} (N = {label_counts.loc[ll]:d})"
Expand Down

0 comments on commit 170875c

Please sign in to comment.