Skip to content

Commit

Permalink
Added more plotting scripts. Moved plotting routines to plots folder.
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohnson643 committed Nov 6, 2021
1 parent 651f2e0 commit 31ebd8a
Show file tree
Hide file tree
Showing 3 changed files with 431 additions and 84 deletions.
170 changes: 95 additions & 75 deletions EMBEDR/embedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,20 +1860,17 @@ def calculate_pValues(self):
raise ValueError(err_str)

def plot(self,
ax=None,
cax=None,
show_cbar=True,
embed_2_show=0,
plot_type='pvalue',
metadata=None,
is_categorical=False,
plot_data=True,
cbar_ticks=None,
cbar_ticklabels=None,
pVal_clr_change=[0, 1, 2, 3, 4],
scatter_s=5,
scatter_alpha=0.4,
scatter_kwds={},
text_kwds={},
cbar_kwds={},
cite_EMBEDR=True):
embed_2_show=0,
fig=None,
axis=None,
cbar_ax=None,
show_cbar=True,
cite_EMBEDR=True,
**plot_kwds):
"""Generates scatter plot of embedded data colored by EMBEDR p-value
Parameters
Expand Down Expand Up @@ -1938,65 +1935,87 @@ def plot(self,
> fig, ax = plt.subplots(1, 1)
> ax = embedr_obj.plot(ax=ax)
"""
import matplotlib.pyplot as plt

fig = plt.gcf()
if ax is None:
ax = fig.gca()

if plot_data:
Y = self.data_Y[embed_2_show]
plot_Y = self.data_Y[embed_2_show]
else:
Y = self.null_Y[embed_2_show]

pVal_cmap = putl.CategoricalFadingCMap(change_points=pVal_clr_change)

pVals = -np.log10(self.pValues)

sort_idx = np.argsort(pVals)

h_ax = ax.scatter(*Y[sort_idx].T,
s=scatter_s,
c=pVals[sort_idx],
cmap=pVal_cmap.cmap,
norm=pVal_cmap.cnorm,
alpha=scatter_alpha,
**scatter_kwds)

if cite_EMBEDR:
_ = ax.text(0.02, 0.02,
"Made with the EMBEDR package.",
fontsize=6,
transform=ax.transAxes,
ha='left',
va='bottom',
**text_kwds)

if show_cbar:
cbar_ax = fig.colorbar(h_ax,
ax=ax,
cax=cax,
boundaries=pVal_cmap.cnorm.boundaries,
ticks=[],
**cbar_kwds)
cbar_ax.ax.invert_yaxis()

if cbar_ticks is None:
cbar_ticks = pVal_cmap.change_points
cbar_ax.set_ticks(cbar_ticks)

if cbar_ticklabels is None:
cbar_ticklabels = ['1',
'0.1',
r"$10^{-2}$",
r"$10^{-3}$",
r"$10^{-4}$"]
cbar_ax.set_ticklabels(cbar_ticklabels)

cbar_ax.ax.tick_params(length=0)
cbar_ax.ax.set_ylabel("EMBEDR p-Value")

return ax
plot_Y = self.null_Y[embed_2_show]

if plot_type.lower() in ['pvalue', 'min_pvalue']:

from EMBEDR.plots.embedr_scatterplots import Scatter_by_pValue

if plot_type.lower() == 'min_pvalue':
if not hasattr(self, 'opt_min_pVals'):
err_str = f"Current EMBEDR object does not contain"
err_str += f" minimal p-Values resulting from a"
err_str += f" hyperparameter sweep. (Current object might"
err_str += f" not be a cell-optimal embedding generated by"
err_str += f" an EMBEDR_sweep.) If you want to plot"
err_str += f" according to p-Value, use `plot_type`="
raise TypeError(err_str + f"'pvalue'.")

pValues = self.opt_min_pVals
else:
pValues = self.pValues

plotObj = Scatter_by_pValue(plot_Y,
pValues,
fig=fig,
axis=axis,
cbar_ax=cbar_ax,
show_cbar=show_cbar,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

elif plot_type.lower() in ['ees']:

from EMBEDR.plots.embedr_scatterplots import Scatterplot

if plot_data:
ees_vals = np.median(self.data_EES, axis=0)
else:
ees_vals = np.median(self.null_EES, axis=0)

plotObj = Scatterplot(plot_Y,
ees_vals,
fig=fig,
axis=axis,
cbar_ax=cbar_ax,
show_cbar=show_cbar,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

elif (metadata is not None) and (plot_type in metadata):

if is_categorical:
from EMBEDR.plots.embedr_scatterplots import Scattergory

plotObj = Scattergory(plot_Y,
plot_type,
metadata,
fig=fig,
axis=axis,
cbar_ax=cbar_ax,
show_cbar=show_cbar,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

else:
from EMBEDR.plots.embedr_scatterplots import Scatterplot

labels = metadata[plot_type].values

plotObj = Scatterplot(plot_Y,
labels,
fig=fig,
axis=axis,
cbar_ax=cbar_ax,
show_cbar=show_cbar,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

return plotObj.plot()


class EMBEDR_sweep(object):
Expand Down Expand Up @@ -2468,12 +2487,12 @@ def get_optimal_hyperparameters(self):
if self.verbose >= 5:
print(f"Returning optimal '{self.sweep_type}' values!")

return opt_sweep_values
return pVal_arr.min(axis=0), opt_sweep_values

def fit_samplewise_optimal(self):

try:
opt_hp_vals = self.get_optimal_hyperparameters()
opt_min_pVals, opt_hp_vals = self.get_optimal_hyperparameters()
except AttributeError:
err_str = f"A hyperparameter sweep must be run before a sample"
err_str += f"-wise optimal embedding can be generated!"
Expand Down Expand Up @@ -2509,11 +2528,14 @@ def fit_samplewise_optimal(self):

optObj.fit(self.data_X)

optObj.opt_min_pVals = opt_min_pVals.copy()

self.opt_obj = optObj
self.opt_embed = optObj.data_Y[:]
self.opt_embed_data_EES = optObj.data_EES[:]
self.opt_embed_null_EES = optObj.null_EES[:]
self.opt_embed_pValues = optObj.pValues[:]
self.opt_min_pVals = opt_min_pVals.copy()

def _get_kEff_from_hp(self, hp_value):
sort_idx = np.argsort(self.sweep_values)
Expand Down Expand Up @@ -2549,7 +2571,7 @@ def sweep_boxplot(self, sweep_type='pvalues', **kwargs):
kEff_dict=self.kEff,
verbose=self.verbose > 3,
**kwargs)
axis = plotObj.make_plot()
axis = plotObj.plot()

if sweep_type.lower() == 'ees':
from EMBEDR.plots.sweep_boxplots import SweepBoxplot_EES
Expand All @@ -2560,7 +2582,7 @@ def sweep_boxplot(self, sweep_type='pvalues', **kwargs):
kEff_dict=self.kEff,
verbose=self.verbose > 3,
**kwargs)
axis = plotObj.make_plot()
axis = plotObj.plot()

return axis

Expand Down Expand Up @@ -2590,8 +2612,6 @@ def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None,
xticks = np.clip(xticks, 0, n_hp - 1)
xticks = np.asarray(xticks).astype(int)

print(xticks)

if len(xticklabels) == 0:
xtlabs = np.asarray([self.kEff[hp_values[idx]]
for idx in xticks])
Expand Down
Loading

0 comments on commit 31ebd8a

Please sign in to comment.