Skip to content

Commit

Permalink
Fixed more bugs in plotting scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohnson643 committed Nov 7, 2021
1 parent 170875c commit d7f3dfa
Show file tree
Hide file tree
Showing 6 changed files with 667 additions and 24 deletions.
18 changes: 11 additions & 7 deletions EMBEDR/embedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _validate_parameters_without_data(self):
return

def __str__(self):

out_str = ""
if self.verbose > 1:
out_str += f"\n\n\tEMBEDR Class v{ev}\n" + 35 * "=" + "\n\n"
Expand Down Expand Up @@ -2636,24 +2637,27 @@ def sweep_lineplot(self, sweep_type='pvalues', metadata=None, labels=None,
if (metadata is None) or (labels is None):
from EMBEDR.plots.sweep_lineplots import sweep_lineplot

axis = sweep_lineplot(hp_values,
return sweep_lineplot(hp_values,
values_dict,
fig=fig,
axis=axis,
# xticks=xticks,
xticklabels=xticklabels, **kwargs)
xticks=xticks,
xticklabels=xticklabels,
**kwargs)

else:
from EMBEDR.plots.sweep_lineplots import sweep_lineplot_byCat
axes = sweep_lineplot_byCat(hp_values,
return sweep_lineplot_byCat(hp_values,
values_dict,
metadata,
labels,
fig=fig,
# xticks=xticks,
xticks=xticks,
xticklabels=xticklabels,
**kwargs)



def plot_embedding(self,
param_2_plot='optimal',
plot_type='pvalue',
Expand Down Expand Up @@ -2690,9 +2694,9 @@ def plot_embedding(self,

perp = knn = None
if self.sweep_type == 'perplexity':
perp = param_2_plot
perp = np.ones((self.n_samples)) * param_2_plot
else:
knn = param_2_plot
knn = np.ones((self.n_samples)) * param_2_plot

tmpEmb = EMBEDR(perplexity=perp,
n_neighbors=knn,
Expand Down
3 changes: 2 additions & 1 deletion EMBEDR/plots/EMBEDR_Figure_01v1_DimRed_Zoology.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def EMBEDR_Figure_01(X,
n_data_embed=1,
n_jobs=-1,
project_name=project_name,
project_dir=project_dir)
project_dir=project_dir,
**EMBEDR_params)
Y, _ = embObj.get_tSNE_embedding(X)
kEff = human_round(embObj.kEff)
title = f"t-SNE: " + r"$k_{Eff} \approx $" + f"{kEff:.0f}"
Expand Down
19 changes: 14 additions & 5 deletions EMBEDR/plots/embedr_scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self,
labels,
log_labels=False,
axis=None,
axis_kwds=None,
show_border=True,
scatter_sizes=3,
scatter_alpha=1,
Expand All @@ -32,6 +33,7 @@ def __init__(self,
yticks=None,
yticklabels=None,
ylabel=None,
grid_kwds=None,
label_size=12,
plot_order=None,
title=None,
Expand All @@ -55,6 +57,7 @@ def __init__(self,
self.labels = np.log10(self.labels)

self.axis = axis
self.axis_kwds = {} if axis_kwds is None else axis_kwds.copy()
self.show_border = show_border

self.sct_s = scatter_sizes
Expand All @@ -74,6 +77,8 @@ def __init__(self,
self.ylabel = ylabel
self.label_size = label_size

self.grid_kwds = {} if grid_kwds is None else grid_kwds.copy()

if plot_order is None:
self.sort_idx = np.arange(self.n_samples)
elif plot_order.lower() == 'asc':
Expand Down Expand Up @@ -107,10 +112,18 @@ def plot(self, **kwargs):
self.fig = self.axis.figure

spine_alpha = 1 if self.show_border else 0
self.axis = putl.make_border_axes(self.axis, spine_alpha=spine_alpha)
self.axis = putl.make_border_axes(self.axis,
spine_alpha=spine_alpha,
xticks=self.xticks,
yticks=self.yticks,
xticklabels=self.xticklabels,
yticklabels=self.xticklabels,
**self.axis_kwds)

self.axis = self._plot(**kwargs)

self.axis.grid(**self.grid_kwds)

if self.cite_EMBEDR:
if 'transform' not in self.text_kwds:
self.text_kwds['transform'] = self.axis.transAxes
Expand Down Expand Up @@ -261,8 +274,6 @@ def __init__(self, Y, label, metadata, **kwargs):
if self.cmap is None:
self.cmap = 'husl'

print(self.category_kwds)

out = putl.process_categorical_label(metadata,
label,
cmap=self.cmap,
Expand All @@ -285,8 +296,6 @@ def _plot(self, **kwargs):

good_idx = self._labels == label
n_labs = sum(good_idx)
# if verbose:
# print(f"There are {n_labs} samples with label = {label}")

if label in self.labels_2_show:

Expand Down
6 changes: 6 additions & 0 deletions EMBEDR/plots/sweep_boxplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class SweepBoxplot(object):

BOX_PATCHES = ['boxes', 'whiskers', 'fliers', 'caps', 'medians']
GRID_KWDS = dict()

def __init__(self,
hyperparam_array,
Expand All @@ -23,6 +24,7 @@ def __init__(self,
back_wpad=0.0,
back_hpad=0.0,
fig_pad=0.4,
grid_kwds=None,
params_2_highlight=None,
box_color=None,
box_fliers=None,
Expand Down Expand Up @@ -100,6 +102,10 @@ def __init__(self,
self.hp_2_hl_idx = np.array([ii for ii, hp in enumerate(self.hp_array)
if hp in self.hp_2_hl]).astype(int)

self.grid_kwds = self.GRID_KWDS.copy()
if grid_kwds is not None:
self.grid_kwds.update(grid_kwds)

self.box_color = box_color
self.box_fliers = box_fliers
self.box_props = box_props
Expand Down
Loading

0 comments on commit d7f3dfa

Please sign in to comment.