Skip to content

Commit

Permalink
Committing updates to plotting utility.
Browse files Browse the repository at this point in the history
  • Loading branch information
ejohnson643 committed Dec 17, 2021
1 parent 76b918c commit 87b96b1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
12 changes: 12 additions & 0 deletions EMBEDR/embedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,18 @@ def plot(self,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

elif plot_type.lower() in ['keff']:
from EMBEDR.plots.embedr_scatterplots import Scatterplot

plotObj = Scatterplot(plot_Y,
self._kEff,
fig=fig,
axis=axis,
cbar_ax=cbar_ax,
show_cbar=show_cbar,
cite_EMBEDR=cite_EMBEDR,
**plot_kwds)

elif (metadata is not None) and (plot_type in metadata):

if is_categorical:
Expand Down
26 changes: 26 additions & 0 deletions EMBEDR/plotting_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,29 @@ def process_categorical_label(metadata, label,

return raw_labels, label_counts, long_labels, lab_2_idx_map, label_cmap


def get_DBSCAN_clusters(Y, min_samples=10, pwd_perc=1.5):
from sklearn.cluster import DBSCAN

PWD = pwd(Y, metric='euclidean')
PWD_triu = np.triu(PWD, k=1)
eps = np.percentile(PWD_triu[PWD_triu.nonzero()], pwd_perc)

DBObj = DBSCAN(eps=eps, min_samples=min_samples)
DBObj.fit(Y)

db_labels = DBObj.labels_

## Count the labels
raw_counts = Counter(db_labels)
## Sort in descending order
db_lab_counts = sorted(raw_counts.items(), key=lambda item: -item[1])
## Remove -1 label
db_lab_counts = {el[0]: el[1] for el in db_lab_counts if el[0] != -1}
## Create mapping from old labels to size-sorted labels
db_lab_remap = {old_lab: new_lab for new_lab, old_lab in enumerate(db_lab_counts.keys())}
## Add -1 to map
if -1 in raw_counts:
db_lab_remap[-1] = -1
## Remap labels
return np.asarray([db_lab_remap[old_lab] for old_lab in db_labels])

0 comments on commit 87b96b1

Please sign in to comment.