diff --git a/EMBEDR/embedr.py b/EMBEDR/embedr.py index 26257a5..0d589db 100644 --- a/EMBEDR/embedr.py +++ b/EMBEDR/embedr.py @@ -1986,6 +1986,19 @@ def plot(self, cite_EMBEDR=cite_EMBEDR, **plot_kwds) + elif plot_type.lower() in ['perplexity']: + + from EMBEDR.plots.embedr_scatterplots import Scatterplot + + plotObj = Scatterplot(plot_Y, + self.perplexity, + fig=fig, + axis=axis, + cbar_ax=cbar_ax, + show_cbar=show_cbar, + cite_EMBEDR=cite_EMBEDR, + **plot_kwds) + elif (metadata is not None) and (plot_type in metadata): if is_categorical: @@ -2641,5 +2654,86 @@ def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None, xticklabels=xticklabels, **kwargs) + def plot_embedding(self, + param_2_plot='optimal', + plot_type='pvalue', + metadata=None, + is_categorical=False, + plot_data=True, + embed_2_show=0, + fig=None, + axis=None, + cbar_ax=None, + show_cbar=True, + cite_EMBEDR=True, + **plot_kwds): + + if param_2_plot == 'optimal': + try: + return self.opt_obj.plot(plot_type=plot_type, + metadata=metadata, + is_categorical=is_categorical, + plot_data=plot_data, + embed_2_show=embed_2_show, + fig=fig, + axis=axis, + cbar_ax=cbar_ax, + show_cbar=show_cbar, + cite_EMBEDR=cite_EMBEDR, + **plot_kwds) + except AttributeError: + err_str = f"Could not plot 'optimal' embedding as it does not" + err_str += f" exist. (Run EMBEDR_sweep.fit_samplewise_optimal" + err_str += f" or set `param_2_plot` to a value that was" + err_str += f" swept by the hyperparameter sweep in `fit`." + raise ValueError(err_str) + + perp = knn = None + if self.sweep_type == 'perplexity': + perp = param_2_plot + else: + knn = param_2_plot + + tmpEmb = EMBEDR(perplexity=perp, + n_neighbors=knn, + kNN_metric=self.kNN_metric, + kNN_alg=self.kNN_alg, + kNN_params=self.kNN_params, + aff_type=self.aff_type, + aff_params=self.aff_params, + n_components=self.n_components, + DRA=self.DRA, + DRA_params=self.DRA_params, + EES_type=self.EES_type, + EES_params=self.EES_params, + pVal_type=self.pVal_type, + n_data_embed=self.n_data_embed, + n_null_embed=self.n_null_embed, + n_jobs=self.n_jobs, + random_state=self.rs, + verbose=0, + do_cache=self.do_cache, + project_name=self.project_name, + project_dir=self.project_dir) + + tmpEmb.data_Y = self.embeddings[param_2_plot] + tmpEmb.data_EES = self.data_EES[param_2_plot] + tmpEmb.null_EES = self.null_EES[param_2_plot] + tmpEmb.pValues = self.pValues[param_2_plot] + + return tmpEmb.plot(plot_type=plot_type, + metadata=metadata, + is_categorical=is_categorical, + plot_data=plot_data, + embed_2_show=embed_2_show, + fig=fig, + axis=axis, + cbar_ax=cbar_ax, + show_cbar=show_cbar, + cite_EMBEDR=cite_EMBEDR, + **plot_kwds) + + + diff --git a/EMBEDR/plots/embedr_scatterplots.py b/EMBEDR/plots/embedr_scatterplots.py index 2c060b8..237d618 100644 --- a/EMBEDR/plots/embedr_scatterplots.py +++ b/EMBEDR/plots/embedr_scatterplots.py @@ -261,6 +261,8 @@ def __init__(self, Y, label, metadata, **kwargs): if self.cmap is None: self.cmap = 'husl' + print(self.category_kwds) + out = putl.process_categorical_label(metadata, label, cmap=self.cmap, @@ -291,21 +293,25 @@ def _plot(self, **kwargs): color = self.label_cmap[self._l2i_map[label]] if self.show_legend: - title = self.long_labels[lNo] + title = self.long_labels[self._l2i_map[label]] if "Multipotent" in title: title = " ".join(title.split(" Multipotent ")) else: title = None + zorder = None + else: color = self.bkgd_label_kwds['color'] title = None + zorder = -1 self.axis.scatter(*self.Y[good_idx, :2].T, color=color, s=self.sct_s[good_idx], alpha=self.sct_a, label=title, + zorder=zorder, **self.sct_kwds) if self.show_legend: diff --git a/EMBEDR/plots/sweep_lineplots.py b/EMBEDR/plots/sweep_lineplots.py index f5d7708..c2368cc 100644 --- a/EMBEDR/plots/sweep_lineplots.py +++ b/EMBEDR/plots/sweep_lineplots.py @@ -161,6 +161,9 @@ def sweep_lineplot_byCat(hyperparam_array, n_labels = len(label_counts) + if labels_2_show == 'all': + labels_2_show = label_counts.index.values + if n_rows is None: n_rows = int(np.ceil(n_labels / n_cols)) diff --git a/EMBEDR/plotting_utility.py b/EMBEDR/plotting_utility.py index a00c28a..fbe9857 100644 --- a/EMBEDR/plotting_utility.py +++ b/EMBEDR/plotting_utility.py @@ -624,7 +624,9 @@ def process_categorical_label(metadata, label, cmap='colorblind', unique_labels = label_counts.index.values if alphabetical_sort: - unique_labels = np.sort(unique_labels) + str_labels = unique_labels.astype(str) + unique_labels = unique_labels[np.argsort(str_labels)] + label_counts = label_counts.reindex(unique_labels) ## Make some nice long labels. long_labels = np.asarray([f"{ll} (N = {label_counts.loc[ll]:d})"