diff --git a/EMBEDR/embedr.py b/EMBEDR/embedr.py index d75bcc3..e307c83 100644 --- a/EMBEDR/embedr.py +++ b/EMBEDR/embedr.py @@ -3,6 +3,7 @@ from EMBEDR._affinity import calculate_kEff as _calc_kEff_from_sparse import EMBEDR.callbacks as cb import EMBEDR.ees as ees +from EMBEDR.human_round import * import EMBEDR.nearest_neighbors as nn import EMBEDR.plotting_utility as putl from EMBEDR.tsne import tSNE_Embed @@ -1945,12 +1946,7 @@ def plot(self, else: Y = self.null_Y[embed_2_show] - [pVal_cmap, - pVal_cnorm] = putl.make_categ_cmap(change_points=pVal_clr_change) - - color_bounds = np.linspace(pVal_clr_change[0], - pVal_clr_change[-1], - pVal_cmap.N) + pVal_cmap = putl.CategoricalFadingCMap(change_points=pVal_clr_change) pVals = -np.log10(self.pValues) @@ -1959,8 +1955,8 @@ def plot(self, h_ax = ax.scatter(*Y[sort_idx].T, s=scatter_s, c=pVals[sort_idx], - cmap=pVal_cmap, - norm=pVal_cnorm, + cmap=pVal_cmap.cmap, + norm=pVal_cmap.cnorm, alpha=scatter_alpha, **scatter_kwds) @@ -1977,13 +1973,13 @@ def plot(self, cbar_ax = fig.colorbar(h_ax, ax=ax, cax=cax, - boundaries=color_bounds, + boundaries=pVal_cmap.cnorm.boundaries, ticks=[], **cbar_kwds) cbar_ax.ax.invert_yaxis() if cbar_ticks is None: - cbar_ticks = pVal_clr_change + cbar_ticks = pVal_cmap.change_points cbar_ax.set_ticks(cbar_ticks) if cbar_ticklabels is None: @@ -2538,3 +2534,246 @@ def _get_hp_from_kEff(self, kEff): kEff = [kEff] return interpolate(x_coords, y_coords, np.asarray(nn)).squeeze() + + def plot_sweep(self, + fig=None, + gridspec=None, + box_widths=None, + box_positions=None, + box_notch=True, + box_bootstrap=100, + box_whiskers=(1, 99), + box_color='grey', + box_fliers=None, + box_props=None, + box_hl_color='grey', + box_hl_props=None, + values_2_highlight=None, + xLabel_idx=None, + xLabel=None, + xLabel_size=16, + xLim=None, + pVal_cmap=None, + fig_size=(12, 5), + fig_pad=0.4, + fig_ppad=0.01, + show_borders=False, + bot_wpad=0.0, + bot_hpad=0.0, + cax_ticklabels=None, + cax_width_frac=1.3, + cax_w2h_ratio=0.1): + """Generates scatter plot of embedded data colored by EMBEDR p-value + + Parameters + ---------- + """ + + import matplotlib.gridspec as gs + import matplotlib.pyplot as plt + + hp_array = np.sort(self.sweep_values) + + if box_fliers is None: + box_fliers = {'marker': ".", + 'markeredgecolor': box_color, + 'markersize': 2, + 'alpha': 0.5} + + if box_props is None: + box_props = {'alpha': 0.5, + 'color': box_color, + 'fill': True} + + if box_hl_props is None: + box_hl_props = box_props.copy() + box_hl_props.update({"alpha": 0.9, "color": box_hl_color}) + + if values_2_highlight is None: + values_2_highlight = [] + + box_patches = ['boxes', 'whiskers', 'fliers', 'caps', 'medians'] + + if fig is None: + fig = plt.figure(figsize=fig_size) + + if gridspec is None: + gridspec = fig.add_gridspec(1, 1) + + if pVal_cmap is None: + pVal_cmap = putl.CategoricalFadingCMap() + + ## Set up large axes + spine_alpha = 1 if show_borders else 0 + bot_ax = fig.add_subplot(gridspec[0]) + bot_ax = putl.make_border_axes(bot_ax, xticks=[], yticks=[], + spine_alpha=spine_alpha) + + ## Set up floating bottom gridspec + bot_gs = gs.GridSpec(nrows=1, ncols=1, + wspace=bot_wpad, hspace=bot_hpad) + + ax = putl.make_border_axes(fig.add_subplot(bot_gs[0]), + yticklabels=[], + yticks=-np.sort(pVal_cmap.change_points), + spine_alpha=1) + + putl.update_tight_bounds(fig, bot_gs, gridspec[0], w_pad=bot_wpad, + h_pad=bot_hpad, fig_pad=fig_pad) + + hl_boxes = {} + hl_idx = [] + for hpNo, hpVal in enumerate(hp_array): + + if box_widths is not None: + try: + box_wid = box_widths[hpNo] + except TypeError as err: + box_wid = box_widths + else: + box_wid = 0.8 + + if box_positions is not None: + try: + box_pos = [box_positions[hpNo]] + except TypeError as err: + box_pos = [box_positions] + else: + box_pos = [hpNo] + + if hpVal in values_2_highlight: + box_pps = box_hl_props.copy() + box_col = box_color + hl_idx.append(hpNo) + else: + box_pps = box_props.copy() + box_col = box_hl_color + + box = ax.boxplot(np.log10(self.pValues[hpVal]), + widths=box_wid, + positions=box_pos, + notch=box_notch, + bootstrap=box_bootstrap, + patch_artist=True, + whis=box_whiskers, + boxprops=box_pps, + flierprops=box_fliers) + + for item in box_patches: + plt.setp(box[item], color=box_col) + + if hpVal in values_2_highlight: + hl_boxes[hpVal] = box['boxes'][0] + + if xLabel_idx is None: + if values_2_highlight: + xLabel_idx = [0] + hl_idx + [hpNo] + else: + if len(hp_array) <= 5: + xLabel_idx = np.arange(len(hp_array)) + else: + xLabel_idx = np.linspace(0, len(hp_array), 5) + xLabel_idx = human_round(xLabel_idx) + xLabel_idx = np.asarray(xLabel_idx).astype(int) + + ax.set_xticks(xLabel_idx) + xticks = [f"{int(self.kEff[hp_array[idx]])}" for idx in xLabel_idx] + xticks = human_round(np.asarray(xticks).squeeze()) + ax.grid(which='major', axis='x', alpha=0) + ax.set_xticklabels(xticks) + + if xLim is None: + xLim = [-1, len(hp_array)] + + ax.set_xlabel(r"$ k_{Eff}$", fontsize=xLabel_size, labelpad=0) + ax.set_xlim(*xLim) + + # ax.set_yticks(-np.sort(pVal_cmap.change_points)) + # ax.set_yticklabels([]) + + ax.set_ylim(-pVal_cmap.change_points.max(), + -pVal_cmap.change_points.min()) + + ax.tick_params(pad=-3) + + ## Update the figure again... + putl.update_tight_bounds(fig, bot_gs, gridspec[0], w_pad=bot_wpad, + h_pad=bot_hpad, fig_pad=fig_pad) + + ## Colorbar parameters + if cax_ticklabels is None: + cax_ticklabels = [f"{10.**(-cp):.1e}" + for cp in pVal_cmap.change_points] + + inv_ax_trans = ax.transAxes.inverted() + fig_trans = fig.transFigure + + ## Convert from data to display + min_pVal = np.min([np.log10(self.pValues[hp].min()) + for hp in self.sweep_values]) + min_pVal = np.min([min_pVal, -pVal_cmap.change_points.max()]) + max_pVal = np.max([np.log10(self.pValues[hp].max()) + for hp in self.sweep_values]) + max_pVal = np.min([max_pVal, -pVal_cmap.change_points.min()]) + min_pVal_crds = ax.transData.transform([xLim[0], min_pVal]) + max_pVal_crds = ax.transData.transform([xLim[0], max_pVal]) + + # print(f"min_pVal_crds: {min_pVal_crds}") + # print(f"max_pVal_crds: {max_pVal_crds}") + + ## Convert from display to figure coordinates + cFigX0, cFigY0 = fig.transFigure.inverted().transform(min_pVal_crds) + cFigX1, cFigY1 = fig.transFigure.inverted().transform(max_pVal_crds) + + # print(f"cFig0: {cFigX0:.4f}, {cFigY0:.4f}") + # print(f"cFig1: {cFigX1:.4f}, {cFigY1:.4f}") + + cFig_height = np.abs(cFigY1 - cFigY0) + cFig_width = cax_w2h_ratio * cFig_height + + # print(f"The color bar will be {cFig_width:.4f} x {cFig_height:.4f}") + + cAxX0, cAxY0 = cFigX0 - cax_width_frac * cFig_width, cFigY0 + cAxX1, cAxY1 = cAxX0 + cFig_width, cFigY0 + cFig_height + + ## Convert from Figure back into Axes + [cAxX0, + cAxY0] = inv_ax_trans.transform(fig_trans.transform([cAxX0, cAxY0])) + [cAxX1, + cAxY1] = inv_ax_trans.transform(fig_trans.transform([cAxX1, cAxY1])) + + # print(f"cAx0: {cAxX0:.4f}, {cAxY0:.4f}") + # print(f"cAx1: {cAxX1:.4f}, {cAxY1:.4f}") + + cAx_height = np.abs(cAxY1 - cAxY0) + cAx_width = np.abs(cAxX1 - cAxX0) + + # print(f"The color bar will be {cAx_width:.4f} x {cAx_height:.4f}") + + caxIns = ax.inset_axes([cAxX0, cAxY0, cAx_width, cAx_height]) + caxIns = putl.make_border_axes(caxIns, spine_alpha=0) + + hax = plt.scatter([], [], c=[], s=[], cmap=pVal_cmap.cmap, + norm=pVal_cmap.cnorm) + cAx = fig.colorbar(hax, cax=caxIns, ticks=[], + boundaries=pVal_cmap.cnorm.boundaries) + cAx.ax.invert_yaxis() + + cAx.set_ticks(pVal_cmap.change_points) + cAx.set_ticklabels(cax_ticklabels) + cAx.ax.tick_params(length=0) + cAx.ax.yaxis.set_ticks_position('left') + + cAx.ax.set_ylabel(r"EMBEDR $p$-Value", + fontsize=xLabel_size, + labelpad=2) + cAx.ax.yaxis.set_label_position('left') + + ## Update the figure again... + putl.update_tight_bounds(fig, bot_gs, gridspec[0], w_pad=bot_wpad, + h_pad=bot_hpad, fig_pad=fig_pad) + + return ax + + + diff --git a/EMBEDR/human_round.py b/EMBEDR/human_round.py index aed0e5f..2d66b20 100644 --- a/EMBEDR/human_round.py +++ b/EMBEDR/human_round.py @@ -54,8 +54,7 @@ def human_round(x, Options are 'up', 'down', and 'none'. """ - if not is_iterable(x): - x = np.asarray([x]) + x = np.asarray([x]).squeeze().astype(float) if not inplace: x = x.copy() diff --git a/EMBEDR/plots/EMBEDR_Figure_01v1_DimRed_Zoology.py b/EMBEDR/plots/EMBEDR_Figure_01v1_DimRed_Zoology.py index 17d95c0..59f36c6 100644 --- a/EMBEDR/plots/EMBEDR_Figure_01v1_DimRed_Zoology.py +++ b/EMBEDR/plots/EMBEDR_Figure_01v1_DimRed_Zoology.py @@ -25,9 +25,13 @@ ############################################################################### """ from EMBEDR.embedr import EMBEDR, EMBEDR_sweep +from EMBEDR.human_round import human_round import EMBEDR.plotting_utility as putl +import EMBEDR.utility as utl +import anndata as ad import matplotlib +import matplotlib.gridspec as gs import matplotlib.pyplot as plt import numpy as np import os @@ -37,6 +41,350 @@ warnings.filterwarnings("ignore", message="This figure includes Axes that") warnings.filterwarnings("ignore", message="tight_layout not applied: ") +warnings.filterwarnings("ignore", message="Creating an ndarray from ragged") + + +def _make_figure_grid(fig_size=(7.2, 5.76), + n_rows=2, + n_cols=2, + show_all_borders=False, + wspace=0.005, + hspace=0.01, + spines_2_show='all', + spine_alpha=0.5, + spine_width=1.0): + + back_spine_alpha = 0 + if show_all_borders: + back_spine_alpha = 1 + + fig = plt.figure(figsize=fig_size) + + back_axis = fig.add_subplot(111) + back_axis = putl.make_border_axes(back_axis, spine_alpha=back_spine_alpha) + + main_gs = fig.add_gridspec(nrows=n_rows, ncols=n_cols, + wspace=wspace, hspace=hspace) + + main_axes = [] + for rowNo in range(n_rows): + axes_row = [] + for colNo in range(n_cols): + ax = fig.add_subplot(main_gs[rowNo, colNo]) + ax = putl.make_border_axes(ax, spines_2_show=spines_2_show, + spine_alpha=spine_alpha, + spine_width=spine_width) + axes_row.append(ax) + main_axes.append(axes_row) + + return fig, back_axis, main_gs, main_axes + + +def _add_plot_colored_by_cluster(Y, + labels, + axis, + colors, + sizes, + labels_2_hl, + scatter_alpha=0.2): + + hax = axis.scatter(*Y.T, c=colors, s=sizes, alpha=scatter_alpha) + + for lNo, lab in enumerate(labels_2_hl): + good_idx = (labels == lab).squeeze() + + label_median = np.median(Y[good_idx], axis=0) + + axis.text(*label_median, "{}".format(lNo + 1), fontsize=12, + fontweight='bold', va='center', ha='center') + + return hax + + +def _add_plot_colored_by_var(Y, + labels, + axis, + sizes, + scatter_alpha=0.2, + reverse_label=False): + + sort_idx = np.argsort(labels) + if reverse_label: + sort_idx = sort_idx[::-1] + + hax = axis.scatter(*Y[sort_idx].T, c=labels[sort_idx], s=sizes[sort_idx], + alpha=scatter_alpha) + + return hax + + + + + +def EMBEDR_Figure_01(X, + metadata=None, + data_dir=None, + embedding_params=None, + EMBEDR_params=None, + project_dir="./", + project_name="EMBEDR_Figure_01v1_DimRedZoology", + color_by_cluster=True, + label_name="cell_ontology_class", + labels_2_hl=None, + label_colors=None, + label_sizes=None, + label_params=None, + grid_params=None, + n_rows=2, + n_cols=2, + scatter_alpha=0.2, + title_size=14, + title_pad=-15, + add_panel_numbers=False, + fig_dir="./", + fig_pad=None): + + if metadata is None: + load_metadata = True + + data_name = "" + if isinstance(X, str): + data_name = X.title() + if load_metadata: + X, metadata = utl.load_data(X, data_dir=data_dir) + else: + X = utl.load_data(X, data_dir=data_dir, load_metadata=False) + + if metadata is None: + err_str = f"Metadata must be either loadable with `utl.load_data`" + err_str += f" or provided. Metadata is currently `None`..." + raise ValueError(err_str) + + if embedding_params is None: + embedding_params = [('tSNE', 9), ('UMAP', 15), + ('tSNE', 350), ('UMAP', 400)] + + if EMBEDR_params is None: + EMBEDR_params = {} + + if color_by_cluster: + if label_params is None: + label_params = {} + + [labels, + label_counts, + long_labels, + lab_2_idx_map, + label_cmap] = putl.process_categorical_label(metadata, + label_name, + **label_params) + + if labels_2_hl is None: + labels_2_hl = label_counts.index.values[:10] + + if label_colors is None: + label_colors = [label_cmap[lab_2_idx_map[ll]] if ll in labels_2_hl + else 'lightgrey'for ll in labels] + + if label_sizes is None: + label_sizes = [3 if ll in labels_2_hl else 1 for ll in labels] + + elif label_sizes is None: + label_sizes = 3 * np.ones((len(X))) + + if grid_params is None: + grid_params = {} + + [fig, + back_axis, + main_gs, + main_axes] = _make_figure_grid(n_rows=n_rows, n_cols=n_cols, + **grid_params) + + if EMBEDR_params is None: + EMBEDR_params = {'verbose': 1} + + for algNo, (alg, param) in enumerate(embedding_params): + print(f"Plotting data embedded by {alg} (param = {param})") + + if alg.lower() in ['tsne', 't-sne']: + embObj = EMBEDR(X=X, + perplexity=param, + DRA='tsne', + n_data_embed=1, + n_jobs=-1, + project_name=project_name, + project_dir=project_dir) + Y, _ = embObj.get_tSNE_embedding(X) + kEff = human_round(embObj.kEff) + title = f"t-SNE: " + r"$k_{Eff} \approx $" + f"{kEff:.0f}" + + if not color_by_cluster: + labels = np.log10(embObj._kEff) + + if alg.lower() in ['umap']: + embObj = EMBEDR(X=X, + n_neighbors=param, + DRA='umap', + n_data_embed=1, + n_jobs=-1, + project_name=project_name, + project_dir=project_dir, + **EMBEDR_params) + Y, _ = embObj.get_UMAP_embedding(X) + title = f"UMAP: " + r"$k = $" + f"{param:.0f}" + + if not color_by_cluster: + kNN_graph = embObj.get_kNN_graph(X) + labels = np.log10(kNN_graph.kNN_dst[:, param - 1]) + + rowNo = int(algNo / n_cols) + colNo = int(algNo % n_cols) + axis = main_axes[rowNo][colNo] + + if color_by_cluster: + hax = _add_plot_colored_by_cluster(Y[0], labels, axis, + label_colors, label_sizes, + labels_2_hl, + scatter_alpha=scatter_alpha) + + else: + hax = _add_plot_colored_by_var(Y[0], labels, axis, label_sizes, + scatter_alpha=scatter_alpha) + + cax = fig.colorbar(hax, ax=axis, pad=-0.002, + drawedges=False) + + c_ticks = cax.get_ticks() + cax.set_ticks(c_ticks, ) + c_ticklabels = [f"{int(human_round(10**tck))}" for tck in c_ticks] + cax.set_ticklabels(c_ticklabels) + cax.ax.yaxis.set_tick_params(pad=-0.5) + + if alg.lower() == 'tsne': + cax.set_label(r"Effective Nearest Neighbors, $k_{Eff}$", + labelpad=-3) + if alg.lower() == 'umap': + cax.set_label(r"Distance to $k^{th}$ Neighbor", + labelpad=-3) + + cax.solids.set_edgecolor('face') + fig.canvas.draw() + + axis.set_title(title, fontsize=title_size, pad=title_pad) + ylim = axis.get_ylim() + axis.set_ylim(ylim[0], ylim[1] + 0.1 * (ylim[1] - ylim[0])) + + if color_by_cluster: + text_off = 0 + text_h = 0 + + ax_width = back_axis.get_window_extent().width + ax_height = back_axis.get_window_extent().height / fig.dpi + + pad = 3 + + rect_width = axis.get_window_extent().width / ax_width + + for lNo, lab in enumerate(labels_2_hl): + + if lNo < 5: + x_loc = rect_width / 2 + rect_x = 0 + else: + x_loc = 1 - (rect_width / 2) + rect_x = 1 - rect_width + + if "Slamf1-Negative" in lab: + lab = " ".join(lab.split(" Multipotent ")) + label_str = f"{lNo + 1}: " + lab.title() + + bb = axis.text(x_loc, -0.013 - (lNo % 5) * text_h, + label_str, ha='center', va='top', fontsize=10, + transform=back_axis.transAxes) + + text_h = (bb.get_size() + 2 * pad) / 72. / ax_height + + # if (cNo % 2) == 1: + # text_off -= text_h + + # rect_y = (int(cNo / 2) + 1) * text_h + rect_y = ((lNo % 5) + 1) * text_h + + label_color = label_cmap[lab_2_idx_map[lab]] + + rect = plt.Rectangle((rect_x, -rect_y), + width=rect_width, + height=text_h, + transform=back_axis.transAxes, + zorder=3, + fill=True, + facecolor=label_color, + clip_on=False, + alpha=0.8, + edgecolor='0.8') + + back_axis.add_patch(rect) + + fig.tight_layout() + + if add_panel_numbers: + + letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + for rowNo in range(n_rows): + for colNo in range(n_cols): + + axis = main_axes[rowNo][colNo] + letter = letters[rowNo * n_cols + colNo] + + _ = putl.add_panel_number(axis, letter, edge_pad=10) + + fig.tight_layout() + + if color_by_cluster: + fig_base = project_name + f"_{data_name}" + "_ColoredByCluster" + fig_pad = 0.5 if fig_pad is not None else fig_pad + else: + fig_base = project_name + f"_{data_name}" + "_ColoredByVariable" + fig_pad = 3 if fig_pad is not None else fig_pad + + print(fig_base) + ## SAVE FIGURE HERE + putl.save_figure(fig, + fig_base, + fig_dir=fig_dir, + tight_layout_pad=fig_pad) + + return fig + + + + + + + + + + + + + + + + + + + + + + + + + + + + + def make_figure(X, cluster_labels, clusters_2_label=None, label_colors=None, diff --git a/EMBEDR/plotting_utility.py b/EMBEDR/plotting_utility.py index 6b51954..b1fafde 100644 --- a/EMBEDR/plotting_utility.py +++ b/EMBEDR/plotting_utility.py @@ -220,35 +220,42 @@ def add_panel_number(axis, ############################################################################### -## Functions for Colorbars +## Functions for Figure Aesthetics ############################################################################### -def make_categ_cmap(change_points=[0, 2, 3, 4, 5], - categorical_cmap=None, - cmap_idx=None, - cmap_dx=0.001, - reverse_last_interval=True): +class CategoricalFadingCMap(object): """Make categorical colormap that fades between colors at specific values. - This function takes in a list of end points + interior points to set as - the edge of regions on a colormap. The function then returns a new - continuous colormap that transitions between these regions (fades to white, - then changes colors). + This class creates a blended categorical-continuous colormap in which + designated colors are faded to black or white in descrete regions. As an + example, this class is used to create the p-value colorbars in the EMBEDR + plotting functions, where different levels of p-values are given different + colors, but within each level, the color also fades as a p-value goes from + one end of the category to the other. This is useful for situations in + which values have discrete bins into which they can be mapped, but we still + want to see the individual variation in points. + + The default arguments are set to that used by EMBEDR for p-value colorbars. Parameters ---------- - change_points: Iterable (optional) + change_points: Iterable (optional, default=[0, 2, 3, 4, 5]) The values at which to change between categories. The end-points (the max and min values to be shown on the colormap) must be supplied so that if 4 categories are desired, `change_points` must contain 4 + 1 - values. - - categorical_cmap: Seaborn colormap object (optional) - A categorical colormap (list of tuples) to which the intervals between - `change_points` will be mapped. - - cmap_idx: Iterable (optional) + values. These values are in units of the measurement being used to + assign colors to points, i.e. if height is being used to color points, + then change_points might be [1ft, 2ft, 4ft, 6ft], so that there are 3 + height categories: 1-2ft, 2-4ft, and 4-6ft. All values outside this + range will be mapped to the minimum and maximum color of the range + (i.e. a 7ft person would have the same color as a 6ft person). + + base_cmap: Union[str, Iterable of tuples] (optional, default='colorblind') + A categorical colormap (list of tuples) or the name of a Seaborn + colormap to which the intervals between `change_points` will be mapped. + + cmap_idx: Iterable (optional, default=[4, 0, 3, 2]) A list of indices that maps the colors in the colormap to the correct interval. This allows for preset colormaps to be remapped by changing `cmap_idx` from [0, 1, 2, 3] to [2, 3, 1, 0], for example. @@ -256,72 +263,181 @@ def make_categ_cmap(change_points=[0, 2, 3, 4, 5], cmap_dx: float (optional, default=0.001) Interval at which to interpolate colors. Smaller will make the colormap seem more continuous, but may have trouble rendering on some - computers. + computers. If `cmap_dx` < 1, this will be interpreted as an interval + size in the units of `change_points`. If `cmap_dx` > 1, this will + be interpreted as the number of interpolation intervals to calculate + across `change_points`. + + cmap_kwds: dict (optional, default={}) + Other keywords to pass to `matplotlib.colors.ListedColormap` object. + + max_divergence: float (optional, default=0.75) + Maximal distance between white/black and the category color to allow + in each region. Setting to 0 will keep the colors constant in each + category, while setting to 1 will allow the colors to fade entirely to + black or white. reverse_last_interval: bool (optional, default=True) Flag indicating whether to reverse the interpolation direction on the last interval. This can be useful to set up a maximal contrast in one part of the colormap. + + fade_to_white: bool (optional, default=True) + Flag indicating whether the colors should fade to white or black within + a category. """ - if categorical_cmap is None: - import seaborn as sns - categorical_cmap = sns.color_palette('colorblind') + def __init__(self, + change_points=[0, 2, 3, 4, 5], + base_cmap='colorblind', + cmap_idx=None, + cmap_dx=0.001, + cmap_kwds=None, + max_divergence=0.75, + reverse_last_interval=True, + fade_to_white=True): + + self.change_points = change_points + self.base_cmap = base_cmap + self.cmap_idx = cmap_idx + self.cmap_dx = cmap_dx + self.cmap_kwds = cmap_kwds + self.max_divergence = max_divergence + + ## Optional flags + self.reverse_last_interval = reverse_last_interval + self.fade_to_white = fade_to_white + + self._validate_parameters() + + self.cmap, self.cnorm = self.make_cmap() + + def _validate_parameters(self): + + try: + self.change_points = np.unique([el for el in self.change_points]) + self.change_points = np.sort(self.change_points).squeeze() + self.n_categ = len(self.change_points) - 1 + except TypeError as te: + err_str = "Input argument `change_points` is not iterable!" + raise TypeError(err_str) + + if isinstance(self.base_cmap, str): + self.base_cmap = sns.color_palette(self.base_cmap) + else: + try: + _ = self.base_cmap[0] + except TypeError as err: + err_str = err.args[0] + f"\n\n\t Input `base_cmap` could" + err_str += f" not be indexed (_ = cmap[0] failed). Make sure" + err_str += f" `base_cmap` is either a subscriptable colormap" + err_str += f" or an iterable containing colors from which to" + err_str += f" create the categorical colormap." + raise TypeError(err_str) + self._n_base_colors = len(self.base_cmap) + + if self.cmap_idx is None: + self.cmap_idx = [4, 0, 3, 2] + list(range(4, self.n_categ)) + + try: + [el for el in self.cmap_idx] + assert len(self.cmap_idx) == self.n_categ + except TypeError as te: + err_str = "Input argument `change_points` is not iterable!" + raise TypeError(err_str) + except AssertionError as ae: + err_str = f"Input size of `cmap_idx` does not map the number of" + err_str += f" categories indicated by `change_points`" + err_str += f" ({self.n_categ} != {len(self.cmap_idx)}). There must" + err_str += f" be one index in `cmap_idx` for each category." + raise ValueError(err_str) - ## Set the list of indices to use from the colormap - if cmap_idx is None: - cmap_idx = [4, 0, 3, 2] + list(range(4, len(change_points) - 1)) + try: + self.cmap_dx = float(self.cmap_dx) + assert self.cmap_dx > 0 + except (AssertionError, ValueError) as err: + err_str = f"Input argument `cmap_dx` must be a positive float." + raise ValueError(err_str) + + if self.cmap_dx > 1: + self.cmap_dx = ((self.change_points[-1] - self.change_points[0]) + / self.cmap_dx) + + if self.cmap_kwds is None: + self.cmap_kwds = {'name': "EMBEDR p-Values (-log10)"} + err_str = f"Input argument `cmap_kwds` must be a dictionary!" + assert isinstance(self.cmap_kwds, dict), err_str + self.cmap_kwds = self.cmap_kwds.copy() + + try: + self.max_divergence = float(self.max_divergence) + assert 1 > self.max_divergence > 0 + except (AssertionError, ValueError) as err: + err_str = f"Input argument `max_divergence` must be in [0, 1]." + raise ValueError(err_str) - ## Set the base colors for regions of the colormap - colors = [categorical_cmap[idx] for idx in cmap_idx] + self.reverse_last_interval = bool(self.reverse_last_interval) + self.fade_to_white = bool(self.fade_to_white) - ## Make an appropriate grid of points on which to set colors. - color_grid = [] - for intNo, end in enumerate(change_points[1:]): - color_grid += list(np.arange(change_points[intNo], end, cmap_dx)) - color_grid += [change_points[-1]] - color_grid = np.sort(np.unique(np.asarray(color_grid)).squeeze()) + def make_cmap(self): - ## Initialize the RGB+ array. - out_colors = np.ones((len(color_grid), 4)) + ## Get the base colors of the categories. + self.base_colors = [self.base_cmap[idx] for idx in self.cmap_idx] - ## Iterate through the grid, setting interpolated colors for each region. - start_idx = 0 - for intNo, start in enumerate(change_points[:-1]): + ## Make an appropriate grid of points on which to set colors. + color_grid = [] + for start, end in zip(self.change_points[:-1], self.change_points[1:]): + color_grid += list(np.arange(start, end, self.cmap_dx)) + color_grid += [self.change_points[-1]] + ## This checks that we didn't double up any grid points. + color_grid = np.sort(np.unique(np.asarray(color_grid)).squeeze()) - ## Get the number of grid points in this interval - N_ticks = int((change_points[intNo + 1] - start) / cmap_dx) - ## If it's the last interval, add an extra. - if intNo == (len(change_points) - 2): - N_ticks += 1 + ## Initialize the RGB+ array. + final_colors = np.ones((len(color_grid), 4)) - ## Iterate through each of the RGB values. - for jj in range(3): + ## Iterate through the category boundaries, setting interpolated colors + ## for each category. + cat_idx = 0 + for catNo, [start, end] in enumerate(zip(self.change_points[:-1], + self.change_points[1:])): + ## Get the number of grid points in this interval + n_ticks = int((end - start) / self.cmap_dx) + ## If it's the last interval, add an extra. + if end == self.change_points[-1]: + n_ticks += 1 - ## Base color for each interval - base_color = colors[intNo][jj] + ## Iterate through each of the RGB values. + for jj in range(3): - ## Maximum divergence from the base color. - upper_bound = 0.75 * (1 - base_color) + base_color + ## Base color for each interval + base_color = self.base_colors[catNo][jj] - ## Interpolated grid for the interval. - intv_color_grid = np.linspace(base_color, upper_bound, N_ticks) + ## Maximum divergence from the base color. + top = 1 if self.fade_to_white else 0 + color_diff = self.max_divergence * (top - base_color) + upper_bound = color_diff + base_color - ## If we're in the last interval, can reverse the direction - if (intNo == (len(change_points) - 2)) and reverse_last_interval: - intv_color_grid = intv_color_grid[::-1] + ## Interpolated grid for the interval. + intv_color_grid = np.linspace(base_color, upper_bound, n_ticks) - ## Set the colors! - out_colors[start_idx:start_idx + N_ticks, jj] = intv_color_grid + ## If we're in the last interval, can reverse the direction + if catNo == (self.n_categ - 1): + if self.reverse_last_interval: + intv_color_grid = intv_color_grid[::-1] - start_idx += N_ticks + ## Set the colors! + final_colors[cat_idx:cat_idx + n_ticks, jj] = intv_color_grid - ## Convert the grids and colors to matplotlib colormaps. - import matplotlib.colors as mcl - out_cmap = mcl.ListedColormap(out_colors) - out_cnorm = mcl.BoundaryNorm(color_grid, out_cmap.N) + cat_idx += n_ticks - return out_cmap, out_cnorm + ## Convert the grids and colors to matplotlib colormaps. + cmap = mcl.ListedColormap(final_colors, **self.cmap_kwds) + cmap.set_extremes(bad='lightgrey', + under=self.base_colors[0], + over=self.base_colors[-1]) + cnorm = mcl.BoundaryNorm(color_grid, cmap.N) + + return cmap, cnorm def make_seq_cmap(color_1, color_2, n_colors=10): @@ -464,3 +580,29 @@ def wrap_strings(str_arr, line_len=26): return out_list + +def process_categorical_label(metadata, label, cmap='colorblind', + alphabetical_sort=False): + + ## Extract the raw labels + raw_labels = metadata[label].values.copy() + + ## Get the unique labels and their counts + label_counts = metadata[label].value_counts() + unique_labels = label_counts.index.values + + if alphabetical_sort: + unique_labels = np.sort(unique_labels) + + ## Make some nice long labels. + long_labels = np.asarray([f"{ll} (N = {label_counts.loc[ll]:d})" + for ll in unique_labels]) + + ## Make a colormap + if isinstance(cmap, str): + label_cmap = sns.color_palette(cmap, len(unique_labels)) + + lab_2_idx_map = {ll: ii for ii, ll in enumerate(unique_labels)} + + return raw_labels, label_counts, long_labels, lab_2_idx_map, label_cmap + diff --git a/EMBEDR/utility.py b/EMBEDR/utility.py index 4dda9a2..ebbd5ff 100644 --- a/EMBEDR/utility.py +++ b/EMBEDR/utility.py @@ -1,5 +1,7 @@ from collections import Counter import numpy as np +import os +import pandas as pd import scanpy as sc from time import time @@ -49,8 +51,9 @@ def load_data(data_name, X = np.loadtxt(data_path).astype(dtype) if load_metadata: - metadata_path = path.join(data_dir, "mnist2500_labels.txt") + metadata_path = os.path.join(data_dir, "mnist2500_labels.txt") metadata = np.loadtxt(metadata_path).astype(int) + metadata = pd.DataFrame(metadata, columns=['label']) elif data_name.lower() in tabula_muris_tissues: @@ -61,6 +64,9 @@ def load_data(data_name, data_path = os.path.join(data_dir, data_file) X = sc.read_h5ad(data_path) + metadata = X.obs.copy() + + X = X.obsm['X_pca'] elif data_name.lower() == "ATAC": @@ -71,8 +77,14 @@ def load_data(data_name, data_path = os.path.join(data_dir, data_file) X = sc.read_h5ad(data_path) + metadata = X.obs.copy() + + X = X.obsm['lsi'] - return X, metadata + if load_metadata: + return X, metadata + else: + return X ###############################################################################