Skip to content

Commit

Permalink
Added a bunch of plotting routines to the EMBEDR_sweep class...
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohnson643 committed Nov 5, 2021
1 parent 385c37b commit 651f2e0
Show file tree
Hide file tree
Showing 5 changed files with 1,175 additions and 239 deletions.
320 changes: 83 additions & 237 deletions EMBEDR/embedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _validate_parameters_without_data(self):
self._keep_affmats = bool(self._keep_affmats)
self.rs = check_random_state(self.rs)
self._seed = self.rs.get_state()[1][0]
self._null_seed = self._seed
try:
self.verbose = float(self.verbose)
except ValueError as e:
Expand Down Expand Up @@ -560,8 +561,6 @@ def _fit(self, null_fit=False):
self.n_components))
self.null_EES = np.zeros((self.n_null_embed, self.n_samples))

self._null_seed = self._seed

for nNo in range(self.n_null_embed):

if self.verbose >= 1:
Expand Down Expand Up @@ -1153,7 +1152,7 @@ def get_tSNE_embedding(self,
aff_mat = self._recalculate_P(aff_mat, kNN_graph)

## Initialize the t-SNE object.
embObj = self._initialize_tSNE_embed(aff_mat)
embObj = self._initialize_tSNE_embed(X, aff_mat)

## Get a locally-normed and asymmetric affinity matrix for EES...
local_aff_mat = self._get_asym_local_affmat(X, kNN_graph=kNN_graph,
Expand Down Expand Up @@ -1269,7 +1268,7 @@ def get_tSNE_embedding(self,

return emb_Y, emb_EES

def _initialize_tSNE_embed(self, affObj):
def _initialize_tSNE_embed(self, X, affObj):

if self.verbose >= 3:
print(f"\nInitializing t-SNE Embedding...")
Expand All @@ -1281,7 +1280,11 @@ def _initialize_tSNE_embed(self, affObj):
verbose=self.verbose,
**self.DRA_params)

embObj.initialize_embedding(affObj)
if 'initialization' in self.DRA_params:
if self.DRA_params['initialization'] == 'pca':
embObj.initialize_embedding(X)
else:
embObj.initialize_embedding(affObj)

return embObj

Expand All @@ -1305,7 +1308,7 @@ def load_tSNE_embedding(self,

## Initialize the embedding object.
if embObj is None:
embObj = self._initialize_tSNE_embed(aff_mat)
embObj = self._initialize_tSNE_embed(X, aff_mat)

## Check for the existence of loaded kNN graphs.
if "Embed_tSNE" not in self.project_hdr:
Expand Down Expand Up @@ -2481,7 +2484,8 @@ def fit_samplewise_optimal(self):
perp = opt_hp_vals
else:
knn = opt_hp_vals
optObj = EMBEDR(perplexity=perp,
optObj = EMBEDR(X=self.data_X,
perplexity=perp,
n_neighbors=knn,
kNN_metric=self.kNN_metric,
kNN_alg=self.kNN_alg,
Expand Down Expand Up @@ -2535,245 +2539,87 @@ def _get_hp_from_kEff(self, kEff):

return interpolate(x_coords, y_coords, np.asarray(nn)).squeeze()

def plot_sweep(self,
fig=None,
gridspec=None,
box_widths=None,
box_positions=None,
box_notch=True,
box_bootstrap=100,
box_whiskers=(1, 99),
box_color='grey',
box_fliers=None,
box_props=None,
box_hl_color='grey',
box_hl_props=None,
values_2_highlight=None,
xLabel_idx=None,
xLabel=None,
xLabel_size=16,
xLim=None,
pVal_cmap=None,
fig_size=(12, 5),
fig_pad=0.4,
fig_ppad=0.01,
show_borders=False,
bot_wpad=0.0,
bot_hpad=0.0,
cax_ticklabels=None,
cax_width_frac=1.3,
cax_w2h_ratio=0.1):
"""Generates scatter plot of embedded data colored by EMBEDR p-value
Parameters
----------
"""

import matplotlib.gridspec as gs
import matplotlib.pyplot as plt

hp_array = np.sort(self.sweep_values)

if box_fliers is None:
box_fliers = {'marker': ".",
'markeredgecolor': box_color,
'markersize': 2,
'alpha': 0.5}

if box_props is None:
box_props = {'alpha': 0.5,
'color': box_color,
'fill': True}

if box_hl_props is None:
box_hl_props = box_props.copy()
box_hl_props.update({"alpha": 0.9, "color": box_hl_color})

if values_2_highlight is None:
values_2_highlight = []

box_patches = ['boxes', 'whiskers', 'fliers', 'caps', 'medians']

if fig is None:
fig = plt.figure(figsize=fig_size)

if gridspec is None:
gridspec = fig.add_gridspec(1, 1)

if pVal_cmap is None:
pVal_cmap = putl.CategoricalFadingCMap()

## Set up large axes
spine_alpha = 1 if show_borders else 0
bot_ax = fig.add_subplot(gridspec[0])
bot_ax = putl.make_border_axes(bot_ax, xticks=[], yticks=[],
spine_alpha=spine_alpha)

## Set up floating bottom gridspec
bot_gs = gs.GridSpec(nrows=1, ncols=1,
wspace=bot_wpad, hspace=bot_hpad)

ax = putl.make_border_axes(fig.add_subplot(bot_gs[0]),
yticklabels=[],
yticks=-np.sort(pVal_cmap.change_points),
spine_alpha=1)

putl.update_tight_bounds(fig, bot_gs, gridspec[0], w_pad=bot_wpad,
h_pad=bot_hpad, fig_pad=fig_pad)

hl_boxes = {}
hl_idx = []
for hpNo, hpVal in enumerate(hp_array):

if box_widths is not None:
try:
box_wid = box_widths[hpNo]
except TypeError as err:
box_wid = box_widths
else:
box_wid = 0.8

if box_positions is not None:
try:
box_pos = [box_positions[hpNo]]
except TypeError as err:
box_pos = [box_positions]
else:
box_pos = [hpNo]
def sweep_boxplot(self, sweep_type='pvalues', **kwargs):
"""Generates boxplots of specified values vs swept hyperparameter."""
if sweep_type.lower() == 'pvalues':
from EMBEDR.plots.sweep_boxplots import SweepBoxplot_pValues

plotObj = SweepBoxplot_pValues(np.sort(self.sweep_values),
self.pValues,
kEff_dict=self.kEff,
verbose=self.verbose > 3,
**kwargs)
axis = plotObj.make_plot()

if sweep_type.lower() == 'ees':
from EMBEDR.plots.sweep_boxplots import SweepBoxplot_EES

EES_dict = {'data':self.data_EES, 'null':self.null_EES}
plotObj = SweepBoxplot_EES(np.sort(self.sweep_values),
EES_dict,
kEff_dict=self.kEff,
verbose=self.verbose > 3,
**kwargs)
axis = plotObj.make_plot()

return axis

def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None,
params_2_highlight=None, fig=None, axis=None,
xticks=[], xticklabels=[], **kwargs):
"""Generates lineplots of each cell vs swept hyperparameters."""

hp_values = np.sort(self.sweep_values)
n_hp = len(hp_values)

if params_2_highlight is None:
hp_2_hl = []
else:
hp_2_hl = params_2_highlight
hp_2_hl_idx = np.array([ii for ii, hp in enumerate(hp_values)
if hp in hp_2_hl]).astype(int)

if hpVal in values_2_highlight:
box_pps = box_hl_props.copy()
box_col = box_color
hl_idx.append(hpNo)
if len(xticks) < 1:
if len(hp_2_hl) > 0:
xticks = [0] + hp_2_hl_idx.tolist() + [n_hp - 1]
else:
box_pps = box_props.copy()
box_col = box_hl_color

box = ax.boxplot(np.log10(self.pValues[hpVal]),
widths=box_wid,
positions=box_pos,
notch=box_notch,
bootstrap=box_bootstrap,
patch_artist=True,
whis=box_whiskers,
boxprops=box_pps,
flierprops=box_fliers)

for item in box_patches:
plt.setp(box[item], color=box_col)

if hpVal in values_2_highlight:
hl_boxes[hpVal] = box['boxes'][0]

if xLabel_idx is None:
if values_2_highlight:
xLabel_idx = [0] + hl_idx + [hpNo]
else:
if len(hp_array) <= 5:
xLabel_idx = np.arange(len(hp_array))
if n_hp <= 5:
xticks = np.arange(n_hp)
else:
xLabel_idx = np.linspace(0, len(hp_array), 5)
xLabel_idx = human_round(xLabel_idx)
xLabel_idx = np.asarray(xLabel_idx).astype(int)

ax.set_xticks(xLabel_idx)
xticks = [f"{int(self.kEff[hp_array[idx]])}" for idx in xLabel_idx]
xticks = human_round(np.asarray(xticks).squeeze())
ax.grid(which='major', axis='x', alpha=0)
ax.set_xticklabels(xticks)

if xLim is None:
xLim = [-1, len(hp_array)]

ax.set_xlabel(r"$ k_{Eff}$", fontsize=xLabel_size, labelpad=0)
ax.set_xlim(*xLim)

# ax.set_yticks(-np.sort(pVal_cmap.change_points))
# ax.set_yticklabels([])

ax.set_ylim(-pVal_cmap.change_points.max(),
-pVal_cmap.change_points.min())

ax.tick_params(pad=-3)

## Update the figure again...
putl.update_tight_bounds(fig, bot_gs, gridspec[0], w_pad=bot_wpad,
h_pad=bot_hpad, fig_pad=fig_pad)

## Colorbar parameters
if cax_ticklabels is None:
cax_ticklabels = [f"{10.**(-cp):.1e}"
for cp in pVal_cmap.change_points]

inv_ax_trans = ax.transAxes.inverted()
fig_trans = fig.transFigure

## Convert from data to display
min_pVal = np.min([np.log10(self.pValues[hp].min())
for hp in self.sweep_values])
min_pVal = np.min([min_pVal, -pVal_cmap.change_points.max()])
max_pVal = np.max([np.log10(self.pValues[hp].max())
for hp in self.sweep_values])
max_pVal = np.min([max_pVal, -pVal_cmap.change_points.min()])
min_pVal_crds = ax.transData.transform([xLim[0], min_pVal])
max_pVal_crds = ax.transData.transform([xLim[0], max_pVal])
xticks = np.unique(np.linspace(0, n_hp, 5))
xticks = np.clip(xticks, 0, n_hp - 1)
xticks = np.asarray(xticks).astype(int)

# print(f"min_pVal_crds: {min_pVal_crds}")
# print(f"max_pVal_crds: {max_pVal_crds}")
print(xticks)

## Convert from display to figure coordinates
cFigX0, cFigY0 = fig.transFigure.inverted().transform(min_pVal_crds)
cFigX1, cFigY1 = fig.transFigure.inverted().transform(max_pVal_crds)
if len(xticklabels) == 0:
xtlabs = np.asarray([self.kEff[hp_values[idx]]
for idx in xticks])
xticklabels = human_round(xtlabs).squeeze().astype(int)

# print(f"cFig0: {cFigX0:.4f}, {cFigY0:.4f}")
# print(f"cFig1: {cFigX1:.4f}, {cFigY1:.4f}")
if sweep_type.lower() == 'pvalues':
values_dict = self.pValues

cFig_height = np.abs(cFigY1 - cFigY0)
cFig_width = cax_w2h_ratio * cFig_height
if (metadata is None) or (labels is None):
from EMBEDR.plots.sweep_lineplots import sweep_lineplot

# print(f"The color bar will be {cFig_width:.4f} x {cFig_height:.4f}")
axis = sweep_lineplot(hp_values,
values_dict,
fig=fig,
axis=axis,
# xticks=xticks,
xticklabels=xticklabels, **kwargs)

cAxX0, cAxY0 = cFigX0 - cax_width_frac * cFig_width, cFigY0
cAxX1, cAxY1 = cAxX0 + cFig_width, cFigY0 + cFig_height

## Convert from Figure back into Axes
[cAxX0,
cAxY0] = inv_ax_trans.transform(fig_trans.transform([cAxX0, cAxY0]))
[cAxX1,
cAxY1] = inv_ax_trans.transform(fig_trans.transform([cAxX1, cAxY1]))

# print(f"cAx0: {cAxX0:.4f}, {cAxY0:.4f}")
# print(f"cAx1: {cAxX1:.4f}, {cAxY1:.4f}")

cAx_height = np.abs(cAxY1 - cAxY0)
cAx_width = np.abs(cAxX1 - cAxX0)

# print(f"The color bar will be {cAx_width:.4f} x {cAx_height:.4f}")

caxIns = ax.inset_axes([cAxX0, cAxY0, cAx_width, cAx_height])
caxIns = putl.make_border_axes(caxIns, spine_alpha=0)

hax = plt.scatter([], [], c=[], s=[], cmap=pVal_cmap.cmap,
norm=pVal_cmap.cnorm)
cAx = fig.colorbar(hax, cax=caxIns, ticks=[],
boundaries=pVal_cmap.cnorm.boundaries)
cAx.ax.invert_yaxis()

cAx.set_ticks(pVal_cmap.change_points)
cAx.set_ticklabels(cax_ticklabels)
cAx.ax.tick_params(length=0)
cAx.ax.yaxis.set_ticks_position('left')

cAx.ax.set_ylabel(r"EMBEDR $p$-Value",
fontsize=xLabel_size,
labelpad=2)
cAx.ax.yaxis.set_label_position('left')

## Update the figure again...
putl.update_tight_bounds(fig, bot_gs, gridspec[0], w_pad=bot_wpad,
h_pad=bot_hpad, fig_pad=fig_pad)

return ax
else:
from EMBEDR.plots.sweep_lineplots import sweep_lineplot_byCat
axes = sweep_lineplot_byCat(hp_values,
values_dict,
metadata,
labels,
fig=fig,
# xticks=xticks,
xticklabels=xticklabels,
**kwargs)



Loading

0 comments on commit 651f2e0

Please sign in to comment.