diff --git a/EMBEDR/embedr.py b/EMBEDR/embedr.py index ba6c97a..3802c91 100644 --- a/EMBEDR/embedr.py +++ b/EMBEDR/embedr.py @@ -2600,63 +2600,40 @@ def sweep_boxplot(self, sweep_type='pvalues', **kwargs): return axis - def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None, - params_2_highlight=None, fig=None, axis=None, - xticks=[], xticklabels=[], **kwargs): + def sweep_lineplot(self, sweep_type='pvalues', metadata=None, label=None, + labels_2_show=None, fig=None, axis=None, **kwargs): """Generates lineplots of each cell vs swept hyperparameters.""" hp_values = np.sort(self.sweep_values) - n_hp = len(hp_values) - - if params_2_highlight is None: - hp_2_hl = [] - else: - hp_2_hl = params_2_highlight - hp_2_hl_idx = np.array([ii for ii, hp in enumerate(hp_values) - if hp in hp_2_hl]).astype(int) - - if len(xticks) < 1: - if len(hp_2_hl) > 0: - xticks = [0] + hp_2_hl_idx.tolist() + [n_hp - 1] - else: - if n_hp <= 5: - xticks = np.arange(n_hp) - else: - xticks = np.unique(np.linspace(0, n_hp, 5)) - xticks = np.clip(xticks, 0, n_hp - 1) - xticks = np.asarray(xticks).astype(int) - - if len(xticklabels) == 0: - xtlabs = np.asarray([self.kEff[hp_values[idx]] - for idx in xticks]) - xticklabels = human_round(xtlabs).squeeze().astype(int) if sweep_type.lower() == 'pvalues': values_dict = self.pValues + elif sweep_type.lower() == 'ees': + values_dict = {key: np.median(val, axis=0) + for key, val in self.data_EES.items()} - if (metadata is None) or (labels is None): - from EMBEDR.plots.sweep_lineplots import sweep_lineplot + if (metadata is None) or (label is None): + from EMBEDR.plots.sweep_lineplots import SweepLineplot - return sweep_lineplot(hp_values, - values_dict, - fig=fig, - axis=axis, - xticks=xticks, - xticklabels=xticklabels, - **kwargs) + plotObj = SweepLineplot(hp_values, + values_dict, + fig=fig, + axis=axis, + **kwargs) else: - from EMBEDR.plots.sweep_lineplots import sweep_lineplot_byCat - return sweep_lineplot_byCat(hp_values, - values_dict, - metadata, - labels, - fig=fig, - xticks=xticks, - xticklabels=xticklabels, - **kwargs) + from EMBEDR.plots.sweep_lineplots import SweepLineplot_Category + plotObj = SweepLineplot_Category(hp_values, + values_dict, + metadata, + label, + labels_2_show=labels_2_show, + fig=fig, + axes=axis, + **kwargs) + return plotObj.plot() def plot_embedding(self, param_2_plot='optimal', diff --git a/EMBEDR/plots/embedr_scatterplots.py b/EMBEDR/plots/embedr_scatterplots.py index fb193e9..082cebe 100644 --- a/EMBEDR/plots/embedr_scatterplots.py +++ b/EMBEDR/plots/embedr_scatterplots.py @@ -196,7 +196,7 @@ def __init__(self, *args, **kwargs): if isinstance(self._cmap, putl.CategoricalFadingCMap): self.cbar_ticks = self._cmap.change_points else: - self.cbar_ticks = [0, 2, 3, 4, 5] + self.cbar_ticks = [0, 1, 2, 3, 5] if self.cbar_ticklabels is None: if self.log_labels: diff --git a/EMBEDR/plots/sweep_lineplots.py b/EMBEDR/plots/sweep_lineplots.py index d59985c..1532625 100644 --- a/EMBEDR/plots/sweep_lineplots.py +++ b/EMBEDR/plots/sweep_lineplots.py @@ -462,6 +462,9 @@ def plot(self): self.axes = gridspec.subplots(sharex=self.ax_sharex, sharey=self.ax_sharey) + if self.axes.ndim != 2: + self.axes = self.axes.reshape(self.n_rows, self.n_cols) + for rowNo in range(len(self.axes)): for colNo in range(self.n_cols): axis = self.axes[rowNo][colNo] diff --git a/EMBEDR/plotting_utility.py b/EMBEDR/plotting_utility.py index f4709a5..9e7b209 100644 --- a/EMBEDR/plotting_utility.py +++ b/EMBEDR/plotting_utility.py @@ -320,7 +320,7 @@ class CategoricalFadingCMap(object): """ def __init__(self, - change_points=[0, 2, 3, 4, 5], + change_points=[0, 1, 2, 3, 5], base_cmap='colorblind', cmap_idx=None, cmap_dx=0.001,