diff --git a/EMBEDR/embedr.py b/EMBEDR/embedr.py
index 3db7f87..26257a5 100644
--- a/EMBEDR/embedr.py
+++ b/EMBEDR/embedr.py
@@ -1860,20 +1860,17 @@ def calculate_pValues(self):
             raise ValueError(err_str)
 
     def plot(self,
-             ax=None,
-             cax=None,
-             show_cbar=True,
-             embed_2_show=0,
+             plot_type='pvalue',
+             metadata=None,
+             is_categorical=False,
              plot_data=True,
-             cbar_ticks=None,
-             cbar_ticklabels=None,
-             pVal_clr_change=[0, 1, 2, 3, 4],
-             scatter_s=5,
-             scatter_alpha=0.4,
-             scatter_kwds={},
-             text_kwds={},
-             cbar_kwds={},
-             cite_EMBEDR=True):
+             embed_2_show=0,
+             fig=None,
+             axis=None,
+             cbar_ax=None,
+             show_cbar=True,
+             cite_EMBEDR=True,
+             **plot_kwds):
         """Generates scatter plot of embedded data colored by EMBEDR p-value
 
         Parameters
@@ -1938,65 +1935,92 @@ def plot(self,
         > fig, ax = plt.subplots(1, 1)
-        > ax = embedr_obj.plot(ax=ax)
+        > ax = embedr_obj.plot(axis=ax)
         """
-        import matplotlib.pyplot as plt
-
-        fig = plt.gcf()
-        if ax is None:
-            ax = fig.gca()
 
         if plot_data:
-            Y = self.data_Y[embed_2_show]
+            plot_Y = self.data_Y[embed_2_show]
         else:
-            Y = self.null_Y[embed_2_show]
-
-        pVal_cmap = putl.CategoricalFadingCMap(change_points=pVal_clr_change)
-
-        pVals = -np.log10(self.pValues)
-
-        sort_idx = np.argsort(pVals)
-
-        h_ax = ax.scatter(*Y[sort_idx].T,
-                          s=scatter_s,
-                          c=pVals[sort_idx],
-                          cmap=pVal_cmap.cmap,
-                          norm=pVal_cmap.cnorm,
-                          alpha=scatter_alpha,
-                          **scatter_kwds)
-
-        if cite_EMBEDR:
-            _ = ax.text(0.02, 0.02,
-                        "Made with the EMBEDR package.",
-                        fontsize=6,
-                        transform=ax.transAxes,
-                        ha='left',
-                        va='bottom',
-                        **text_kwds)
-
-        if show_cbar:
-            cbar_ax = fig.colorbar(h_ax,
-                                   ax=ax,
-                                   cax=cax,
-                                   boundaries=pVal_cmap.cnorm.boundaries,
-                                   ticks=[],
-                                   **cbar_kwds)
-            cbar_ax.ax.invert_yaxis()
-
-            if cbar_ticks is None:
-                cbar_ticks = pVal_cmap.change_points
-            cbar_ax.set_ticks(cbar_ticks)
-
-            if cbar_ticklabels is None:
-                cbar_ticklabels = ['1',
-                                   '0.1',
-                                   r"$10^{-2}$",
-                                   r"$10^{-3}$",
-                                   r"$10^{-4}$"]
-            cbar_ax.set_ticklabels(cbar_ticklabels)
-
-            cbar_ax.ax.tick_params(length=0)
-            cbar_ax.ax.set_ylabel("EMBEDR p-Value")
-
-        return ax
+            plot_Y = self.null_Y[embed_2_show]
+
+        if plot_type.lower() in ['pvalue', 'min_pvalue']:
+
+            from EMBEDR.plots.embedr_scatterplots import Scatter_by_pValue
+
+            if plot_type.lower() == 'min_pvalue':
+                if not hasattr(self, 'opt_min_pVals'):
+                    err_str = f"Current EMBEDR object does not contain"
+                    err_str += f" minimal p-Values resulting from a"
+                    err_str += f" hyperparameter sweep. (Current object might"
+                    err_str += f" not be a cell-optimal embedding generated by"
+                    err_str += f" an EMBEDR_sweep.) If you want to plot"
+                    err_str += f" according to p-Value, use `plot_type`="
+                    raise TypeError(err_str + f"'pvalue'.")
+
+                pValues = self.opt_min_pVals
+            else:
+                pValues = self.pValues
+
+            plotObj = Scatter_by_pValue(plot_Y,
+                                        pValues,
+                                        fig=fig,
+                                        axis=axis,
+                                        cbar_ax=cbar_ax,
+                                        show_cbar=show_cbar,
+                                        cite_EMBEDR=cite_EMBEDR,
+                                        **plot_kwds)
+
+        elif plot_type.lower() in ['ees']:
+
+            from EMBEDR.plots.embedr_scatterplots import Scatterplot
+
+            if plot_data:
+                ees_vals = np.median(self.data_EES, axis=0)
+            else:
+                ees_vals = np.median(self.null_EES, axis=0)
+
+            plotObj = Scatterplot(plot_Y,
+                                  ees_vals,
+                                  fig=fig,
+                                  axis=axis,
+                                  cbar_ax=cbar_ax,
+                                  show_cbar=show_cbar,
+                                  cite_EMBEDR=cite_EMBEDR,
+                                  **plot_kwds)
+
+        elif (metadata is not None) and (plot_type in metadata):
+
+            if is_categorical:
+                from EMBEDR.plots.embedr_scatterplots import Scattergory
+
+                plotObj = Scattergory(plot_Y,
+                                      plot_type,
+                                      metadata,
+                                      fig=fig,
+                                      axis=axis,
+                                      cbar_ax=cbar_ax,
+                                      show_cbar=show_cbar,
+                                      cite_EMBEDR=cite_EMBEDR,
+                                      **plot_kwds)
+
+            else:
+                from EMBEDR.plots.embedr_scatterplots import Scatterplot
+
+                labels = metadata[plot_type].values
+
+                plotObj = Scatterplot(plot_Y,
+                                      labels,
+                                      fig=fig,
+                                      axis=axis,
+                                      cbar_ax=cbar_ax,
+                                      show_cbar=show_cbar,
+                                      cite_EMBEDR=cite_EMBEDR,
+                                      **plot_kwds)
+
+        else:
+            err_str = f"Unknown `plot_type` '{plot_type}'. Expected 'pvalue',"
+            err_str += f" 'min_pvalue', 'ees', or a column of `metadata`."
+            raise ValueError(err_str)
+
+        return plotObj.plot()
 
 
 class EMBEDR_sweep(object):
@@ -2468,12 +2492,12 @@ def get_optimal_hyperparameters(self):
         if self.verbose >= 5:
             print(f"Returning optimal '{self.sweep_type}' values!")
 
-        return opt_sweep_values
+        return pVal_arr.min(axis=0), opt_sweep_values
 
     def fit_samplewise_optimal(self):
 
         try:
-            opt_hp_vals = self.get_optimal_hyperparameters()
+            opt_min_pVals, opt_hp_vals = self.get_optimal_hyperparameters()
         except AttributeError:
             err_str = f"A hyperparameter sweep must be run before a sample"
             err_str += f"-wise optimal embedding can be generated!"
@@ -2509,11 +2533,14 @@ def fit_samplewise_optimal(self):
 
         optObj.fit(self.data_X)
 
+        optObj.opt_min_pVals = opt_min_pVals.copy()
+
         self.opt_obj = optObj
         self.opt_embed = optObj.data_Y[:]
         self.opt_embed_data_EES = optObj.data_EES[:]
         self.opt_embed_null_EES = optObj.null_EES[:]
         self.opt_embed_pValues = optObj.pValues[:]
+        self.opt_min_pVals = opt_min_pVals.copy()
 
     def _get_kEff_from_hp(self, hp_value):
         sort_idx = np.argsort(self.sweep_values)
@@ -2549,7 +2576,7 @@ def sweep_boxplot(self, sweep_type='pvalues', **kwargs):
                                            kEff_dict=self.kEff,
                                            verbose=self.verbose > 3,
                                            **kwargs)
-            axis = plotObj.make_plot()
+            axis = plotObj.plot()
 
         if sweep_type.lower() == 'ees':
             from EMBEDR.plots.sweep_boxplots import SweepBoxplot_EES
@@ -2560,7 +2587,7 @@ def sweep_boxplot(self, sweep_type='pvalues', **kwargs):
                                        kEff_dict=self.kEff,
                                        verbose=self.verbose > 3,
                                        **kwargs)
-            axis = plotObj.make_plot()
+            axis = plotObj.plot()
 
         return axis
 
@@ -2590,8 +2617,6 @@ def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None,
         xticks = np.clip(xticks, 0, n_hp - 1)
         xticks = np.asarray(xticks).astype(int)
 
-        print(xticks)
-
         if len(xticklabels) == 0:
             xtlabs = np.asarray([self.kEff[hp_values[idx]] for idx in xticks])
 
diff --git a/EMBEDR/plots/embedr_scatterplots.py b/EMBEDR/plots/embedr_scatterplots.py
new file mode 100644
index 0000000..2c060b8
--- /dev/null
+++ b/EMBEDR/plots/embedr_scatterplots.py
@@ -0,0 +1,314 @@
+from EMBEDR.human_round import human_round
+import EMBEDR.plotting_utility as putl
+import matplotlib
+import matplotlib.gridspec as gs
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+class Scatterplot(object):
+    """Scatter plot of an embedding colored by per-sample values.
+
+    Only the first two columns of `Y` are drawn.  Keyword arguments that
+    `__init__` does not name are accepted and ignored.
+    """
+
+    EMBEDR_text_defaults = {'x': 0.02,
+                            'y': 0.02,
+                            's': "Made with the EMBEDR package.",
+                            'fontsize': 6,
+                            'ha': 'left',
+                            'va': 'bottom'}
+
+    def __init__(self,
+                 Y,
+                 labels,
+                 log_labels=False,
+                 axis=None,
+                 show_border=True,
+                 scatter_sizes=3,
+                 scatter_alpha=1,
+                 scatter_kwds=None,
+                 cmap=None,
+                 cmap_kwds=None,
+                 xticks=None,
+                 xticklabels=None,
+                 xlabel=None,
+                 yticks=None,
+                 yticklabels=None,
+                 ylabel=None,
+                 label_size=12,
+                 plot_order=None,
+                 title=None,
+                 title_size=16,
+                 title_pad=0,
+                 show_cbar=True,
+                 cbar_ax=None,
+                 cbar_ticks=None,
+                 cbar_ticklabels=None,
+                 cbar_label=None,
+                 cite_EMBEDR=True,
+                 text_kwds=None,
+                 fig=None,
+                 **kwargs):
+        """Store the data and styling options; nothing is drawn here.
+
+        `labels` is log10-transformed when `log_labels` is True.
+        `plot_order` may be None, 'asc', or 'desc'; anything else raises
+        a ValueError.  `fig` is only consulted when `axis` is None.
+        """
+
+        self.Y = Y
+        self.n_samples, self.n_components = Y.shape
+        self.labels = labels
+
+        self.log_labels = log_labels
+        if self.log_labels:
+            self.labels = np.log10(self.labels)
+
+        self.fig = fig
+        self.axis = axis
+        self.show_border = show_border
+
+        self.sct_s = scatter_sizes
+        if np.isscalar(self.sct_s):
+            self.sct_s = np.ones((self.n_samples)) * self.sct_s
+        self.sct_a = scatter_alpha
+        self.sct_kwds = {} if scatter_kwds is None else scatter_kwds.copy()
+
+        self.cmap = cmap
+        self.cmap_kwds = {} if cmap_kwds is None else cmap_kwds.copy()
+
+        self.xticks = xticks
+        self.xticklabels = xticklabels
+        self.xlabel = xlabel
+        self.yticks = yticks
+        self.yticklabels = yticklabels
+        self.ylabel = ylabel
+        self.label_size = label_size
+
+        if plot_order is None:
+            self.sort_idx = np.arange(self.n_samples)
+        elif plot_order.lower() == 'asc':
+            self.sort_idx = np.argsort(labels)
+        elif plot_order.lower() == 'desc':
+            self.sort_idx = np.argsort(labels)[::-1]
+        else:
+            raise ValueError(f"Unknown plot ordering '{plot_order}'.")
+
+        self.title = title
+        self.title_size = title_size
+        self.title_pad = title_pad
+
+        self.show_cbar = show_cbar
+        self.cax = cbar_ax
+        self.cbar_ticks = cbar_ticks
+        self.cbar_ticklabels = cbar_ticklabels
+        self.cbar_label = cbar_label
+
+        self.cite_EMBEDR = cite_EMBEDR
+        self.text_kwds = self.EMBEDR_text_defaults.copy()
+        self.text_kwds.update({} if text_kwds is None else text_kwds.copy())
+
+    def plot(self, **kwargs):
+        """Draw the scatter plot and return the matplotlib axis."""
+
+        if self.axis is None:
+            if self.fig is None:
+                self.fig, self.axis = plt.subplots(1, 1, figsize=(8, 6))
+            else:
+                self.axis = self.fig.gca()
+        else:
+            self.fig = self.axis.figure
+
+        spine_alpha = 1 if self.show_border else 0
+        self.axis = putl.make_border_axes(self.axis, spine_alpha=spine_alpha)
+
+        self.axis = self._plot(**kwargs)
+
+        if self.cite_EMBEDR:
+            if 'transform' not in self.text_kwds:
+                self.text_kwds['transform'] = self.axis.transAxes
+            _ = self.axis.text(**self.text_kwds)
+
+        self.fig.tight_layout()
+
+        return self.axis
+
+    def _plot(self, **kwargs):
+        """Draw the points and, if requested, the colorbar."""
+
+        h_ax = self.axis.scatter(*self.Y[self.sort_idx, :2].T,
+                                 c=self.labels[self.sort_idx],
+                                 s=self.sct_s[self.sort_idx],
+                                 alpha=self.sct_a,
+                                 cmap=self.cmap,
+                                 **self.sct_kwds)
+
+        if self.show_cbar:
+            self.cax = self.fig.colorbar(h_ax,
+                                         ax=self.axis,
+                                         cax=self.cax)
+
+            if self.cbar_ticks is not None:
+                self.cax.set_ticks(self.cbar_ticks)
+            if self.cbar_ticklabels is not None:
+                self.cax.set_ticklabels(self.cbar_ticklabels)
+
+            if self.cbar_label is not None:
+                self.cax.set_label(self.cbar_label)
+
+        return self.axis
+
+
+class Scatter_by_pValue(Scatterplot):
+    """Scatterplot of p-values, colored by -log10(p) unless told otherwise."""
+
+    def __init__(self, *args, **kwargs):
+
+        # p-Values are log-scaled unless the caller passes log_labels.
+        kwargs.setdefault('log_labels', True)
+
+        super(Scatter_by_pValue, self).__init__(*args, **kwargs)
+
+        if self.log_labels:
+            self.labels = -self.labels
+
+        if self.xticks is None:
+            self.xticks = []
+            self.xticklabels = []
+
+        if self.yticks is None:
+            self.yticks = []
+            self.yticklabels = []
+
+        if self.cmap is None:
+            self._cmap = putl.CategoricalFadingCMap(**self.cmap_kwds)
+            self.cmap = self._cmap.cmap
+            self.cnorm = self._cmap.cnorm
+        elif isinstance(self.cmap, putl.CategoricalFadingCMap):
+            self._cmap = self.cmap
+            self.cmap = self._cmap.cmap
+            self.cnorm = self._cmap.cnorm
+        else:
+            # Any other colormap: no change points and no norm available.
+            self._cmap = None
+            self.cnorm = None
+
+        if (self.cbar_ticks is None):
+            if isinstance(self._cmap, putl.CategoricalFadingCMap):
+                self.cbar_ticks = self._cmap.change_points
+            else:
+                self.cbar_ticks = [0, 2, 3, 4, 5]
+
+        if self.cbar_ticklabels is None:
+            if self.log_labels:
+                self.cbar_ticklabels = [f"{10.**(-ct):.1e}"
+                                        for ct in self.cbar_ticks]
+            else:
+                self.cbar_ticklabels = [f"{ct:.1e}" for ct in self.cbar_ticks]
+
+        if self.cbar_label is None:
+            self.cbar_label = r"EMBEDR $p$-Value"
+
+    def _plot(self, **kwargs):
+        """Draw the points and, if requested, the inverted colorbar."""
+
+        h_ax = self.axis.scatter(*self.Y[self.sort_idx, :2].T,
+                                 c=self.labels[self.sort_idx],
+                                 s=self.sct_s[self.sort_idx],
+                                 alpha=self.sct_a,
+                                 cmap=self.cmap,
+                                 norm=self.cnorm,
+                                 **self.sct_kwds)
+
+        if self.show_cbar:
+            bounds = self.cnorm.boundaries if self.cnorm is not None else None
+            self.cax = self.fig.colorbar(h_ax,
+                                         ax=self.axis,
+                                         cax=self.cax,
+                                         boundaries=bounds)
+            self.cax.ax.invert_yaxis()
+
+            self.cax.set_ticks(self.cbar_ticks)
+            self.cax.set_ticklabels(self.cbar_ticklabels)
+
+            self.cax.ax.tick_params(length=0)
+            self.cax.set_label(self.cbar_label)
+
+        return self.axis
+
+
+class Scattergory(Scatterplot):
+    """Scatterplot of a categorical metadata column, one color per label."""
+
+    LEGEND_KWDS_DEFAULT = dict(bbox_to_anchor=(1.04, 1), loc="upper left")
+
+    def __init__(self, Y, label, metadata, **kwargs):
+
+        self.labels_2_show = kwargs.pop('labels_2_show', None)
+        self.category_kwds = kwargs.pop('category_kwds', {})
+        self.bkgd_label_kwds = kwargs.pop('bkgd_label_kwds',
+                                          {'color': 'lightgrey',
+                                           's': 3,
+                                           'alpha': 0.5})
+        self.show_legend = kwargs.pop('show_legend', True)
+
+        # Copy so per-instance options never leak into the class default.
+        self.legend_kwds = dict(self.LEGEND_KWDS_DEFAULT)
+        self.legend_kwds.update(kwargs.pop('legend_kwds', None) or {})
+
+        super(Scattergory, self).__init__(Y, [], **kwargs)
+
+        if self.cmap is None:
+            self.cmap = 'husl'
+
+        out = putl.process_categorical_label(metadata,
+                                             label,
+                                             cmap=self.cmap,
+                                             **self.category_kwds)
+
+        self._labels = out[0]
+        self._label_counts = out[1]
+        self.long_labels = [ll.title() for ll in out[2]]
+        self._l2i_map = out[3]
+        self.label_cmap = out[4]
+
+        self._n_labels = len(self._label_counts)
+
+        if self.labels_2_show is None:
+            self.labels_2_show = self._label_counts.index.values
+
+    def _plot(self, **kwargs):
+        """Draw one scatter call per label, plus the optional legend."""
+
+        for lNo, label in enumerate(self._label_counts.index):
+
+            good_idx = self._labels == label
+
+            if label in self.labels_2_show:
+
+                color = self.label_cmap[self._l2i_map[label]]
+
+                if self.show_legend:
+                    title = self.long_labels[lNo]
+                    if "Multipotent" in title:
+                        title = " ".join(title.split(" Multipotent "))
+                else:
+                    title = None
+
+            else:
+                color = self.bkgd_label_kwds['color']
+                title = None
+
+            self.axis.scatter(*self.Y[good_idx, :2].T,
+                              color=color,
+                              s=self.sct_s[good_idx],
+                              alpha=self.sct_a,
+                              label=title,
+                              **self.sct_kwds)
+
+        if self.show_legend:
+            self.axis.legend(**self.legend_kwds)
+
+        return self.axis
diff --git a/EMBEDR/plots/sweep_boxplots.py b/EMBEDR/plots/sweep_boxplots.py
index 97f05d8..48b73cc 100644
--- a/EMBEDR/plots/sweep_boxplots.py
+++ b/EMBEDR/plots/sweep_boxplots.py
@@ -157,13 +157,13 @@ def update_tight_bounds(self):
                                           fig_pad=self.fig_pad)
         return
 
-    def make_plot(self, **kwargs):
+    def plot(self, **kwargs):
 
         self.axis = self.fig.add_subplot(self.inner_gs[0])
         self.axis = putl.make_border_axes(self.axis,
                                           spine_alpha=self.spine_alpha)
 
-        self.axis = self._make_plot(**kwargs)
+        self.axis = self._plot(**kwargs)
 
         self.axis.grid(which='major', axis='x', alpha=0)
 
@@ -185,7 +185,7 @@ def make_plot(self, **kwargs):
 
         return self.axis
 
-    def _make_plot(self, **kwargs):
+    def _plot(self, **kwargs):
 
         return self.axis
 
@@ -263,9 +263,9 @@ def __init__(self, *args, **kwargs):
 
         return
 
-    def make_plot(self, **kwargs):
+    def plot(self, **kwargs):
 
-        self.axis = super().make_plot(**kwargs)
+        self.axis = super().plot(**kwargs)
 
         self._cAx = self._add_colorbar()
 
@@ -273,7 +273,7 @@ def make_plot(self, **kwargs):
 
         return self.axis
 
-    def _make_plot(self, **kwargs):
+    def _plot(self, **kwargs):
 
         hl_boxes = {}
         for hpNo, hpVal in enumerate(self.hp_array):
@@ -431,15 +431,15 @@ def __init__(self, *args, **kwargs):
         if self.box_positions is None:
             self.box_positions = np.arange(self.n_hp)
 
-    def make_plot(self, **kwargs):
+    def plot(self, **kwargs):
 
-        self.axis = super().make_plot(**kwargs)
+        self.axis = super().plot(**kwargs)
 
         self.update_tight_bounds()
 
         return self.axis
 
-    def _make_plot(self, **kwargs):
+    def _plot(self, **kwargs):
 
         min_pVal, max_pVal = np.inf, -np.inf
         hl_boxes = {}