Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: simple HTML representation for LosslessPipeline #146

Merged
merged 7 commits into from
Oct 30, 2023
2 changes: 1 addition & 1 deletion pylossless/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Python port of EEG-IP-L pipeline for preprocessing EEG."""

from . import pipeline, config, bids
from . import pipeline, bids, config, flagging, utils
from .pipeline import LosslessPipeline
4 changes: 2 additions & 2 deletions pylossless/dash/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Dash based helper functions for Lossless QC procedure."""

from mne_icalabel.config import ICLABEL_LABELS_TO_MNE
from mne_icalabel.config import ICA_LABELS_TO_MNE

IC_COLORS = [
"#2c2c2c",
Expand All @@ -16,4 +16,4 @@
"plum",
]

ic_label_cmap = dict(zip(ICLABEL_LABELS_TO_MNE.values(), IC_COLORS))
ic_label_cmap = dict(zip(ICA_LABELS_TO_MNE.values(), IC_COLORS))
8 changes: 4 additions & 4 deletions pylossless/dash/qcgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def set_visualizers(self):

def update_bad_ics(self, annotator="manual"):
"""Add IC name to raw.info['bads'] after selection by user in app."""
df = self.pipeline.flags["ic"].data_frame
df = self.pipeline.flags["ic"]
manual_labels_df = pd.DataFrame(
dict(
component=self.raw_ica.info["bads"],
Expand All @@ -157,7 +157,7 @@ def update_bad_ics(self, annotator="manual"):
)
)
df = pd.concat((df[df.annotator != annotator], manual_labels_df))
self.pipeline.flags["ic"].data_frame = df
self.pipeline.flags["ic"].__init__(df)

def set_layout(self, disable_buttons=False):
"""Create the app.layout for the app object.
Expand Down Expand Up @@ -272,7 +272,7 @@ def load_recording(self, fpath, verbose=False):
self.raw_ica = mne.io.RawArray(sources, info, verbose=verbose)
self.raw_ica.set_meas_date(self.raw.info["meas_date"])
self.raw_ica.set_annotations(self.raw.annotations)
df = self.pipeline.flags["ic"].data_frame
df = self.pipeline.flags["ic"]

bads = [
ic_name
Expand All @@ -283,7 +283,7 @@ def load_recording(self, fpath, verbose=False):
else:
self.raw_ica = None

df = self.pipeline.flags["ic"].data_frame
df = self.pipeline.flags["ic"]
self.ic_types = df[df.annotator == "ic_label"]
self.ic_types = self.ic_types.set_index("component")["ic_type"]
self.ic_types = self.ic_types.to_dict()
Expand Down
21 changes: 15 additions & 6 deletions pylossless/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import mne_icalabel

from ._utils import _icalabel_to_data_frame
from .utils._utils import _icalabel_to_data_frame


class FlaggedChs(dict):
Expand Down Expand Up @@ -52,6 +52,16 @@
super().__init__(*args, **kwargs)
self.ll = ll

def __repr__(self):
"""Return a string representation of the FlaggedChs object."""
return (
f"Flagged channels: |\n"
f" Noisy: {self.get('ch_sd', None)}\n"
f" Bridged: {self.get('bridge', None)}\n"
f" Uncorrelated: {self.get('low_r', None)}\n"
f" Rank: {self.get('rank', None)}\n"
)

def add_flag_cat(self, kind, bad_ch_names, *args):
"""Store channel names that have been flagged by pipeline.

Expand Down Expand Up @@ -197,7 +207,7 @@
self[annot["description"]].append(inds)


class FlaggedICs(dict):
class FlaggedICs(pd.DataFrame):
"""Object for handling IC classification in an mne ICA object.

Attributes
Expand Down Expand Up @@ -239,7 +249,6 @@
"""
super().__init__(*args, **kwargs)
self.fname = None
self.data_frame = None

def label_components(self, epochs, ica):
"""Classify components using mne_icalabel.
Expand All @@ -257,7 +266,7 @@
:func:`mne_icalabel.label_components`. Must be one of: `'iclabel'`.
"""
mne_icalabel.label_components(epochs, ica, method="iclabel")
self.data_frame = _icalabel_to_data_frame(ica)
self.__init__(_icalabel_to_data_frame(ica))

def save_tsv(self, fname):
"""Save IC labels.
Expand All @@ -268,12 +277,12 @@
The output filename.
"""
self.fname = fname
self.data_frame.to_csv(fname, sep="\t", index=False, na_rep="n/a")
self.to_csv(fname, sep="\t", index=False, na_rep="n/a")

Check warning on line 280 in pylossless/flagging.py

View check run for this annotation

Codecov / codecov/patch

pylossless/flagging.py#L280

Added line #L280 was not covered by tests

# TODO: Add parameters.
def load_tsv(self, fname, data_frame=None):
"""Load flagged ICs from file."""
self.fname = fname
if data_frame is None:
data_frame = pd.read_csv(fname, sep="\t")
self.data_frame = data_frame
self.__init__(data_frame)

Check warning on line 288 in pylossless/flagging.py

View check run for this annotation

Codecov / codecov/patch

pylossless/flagging.py#L288

Added line #L288 was not covered by tests
53 changes: 53 additions & 0 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .config import Config
from .flagging import FlaggedChs, FlaggedEpochs, FlaggedICs
from ._logging import lossless_logger, lossless_time
from .utils.html import _get_ics, _sum_flagged_times, _create_html_details


def epochs_to_xr(epochs, kind="ch", ica=None):
Expand Down Expand Up @@ -530,6 +531,58 @@ def __init__(self, config_fname=None):
self.ica1 = None
self.ica2 = None

def _repr_html_(self):
ch_flags = self.flags.get("ch", None)
df = self.flags["ic"]

eog = _get_ics(df, "eog")
ecg = _get_ics(df, "ecg")
muscle = _get_ics(df, "muscle")
line_noise = _get_ics(df, "line_noise")
channel_noise = _get_ics(df, "channel_noise")

lossless_flags = [
"bad_pylossless_ch_sd",
"bad_pylossless_low_r",
"bad_pylossless_ic_sd1",
]
flagged_times = _sum_flagged_times(self.raw, lossless_flags)

config_fname = self.config_fname
raw = self.raw.filenames if self.raw else "Not specified"

html = "<h3>LosslessPipeline</h3>"
html += "<table>"
html += f"<tr><td><strong>Raw</strong></td><td>{raw}</td></tr>"
html += f"<tr><td><strong>Config</strong></td><td>{config_fname}</td></tr>"
html += "</table>"

# Flagged Channels
flagged_channels_data = {
"Noisy": ch_flags.get("ch_sd", None),
"Bridged": ch_flags.get("bridge", None),
"Uncorrelated": ch_flags.get("low_r", None),
}
html += _create_html_details("Flagged Channels", flagged_channels_data)

# Flagged ICs
flagged_ics_data = {
"EOG (Eye)": eog,
"ECG (Heart)": ecg,
"Muscle": muscle,
"Line Noise": line_noise,
"Channel Noise": channel_noise,
}
html += _create_html_details("Flagged ICs", flagged_ics_data)

# Flagged Times
flagged_times_data = flagged_times
html += _create_html_details(
"Flagged Times (Total)", flagged_times_data, times=True
)

return html

def load_config(self):
"""Load the config file."""
self.config = Config().read(self.config_fname)
Expand Down
13 changes: 13 additions & 0 deletions pylossless/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,22 @@ def load_openneuro_bids():
return raw, config, bids_root


def test_empty_repr(tmp_path):
"""Test the __repr__ method for a pipeline that hasn't run."""
config = ll.config.Config()
config.load_default()
fpath = tmp_path / "test_config.yaml"
config.save(fpath)
pipeline = ll.LosslessPipeline(fpath)
assert pipeline.__repr__()
assert pipeline.flags["ch"].__repr__()


def test_pipeline_run(pipeline_fixture):
"""Test running the pipeline."""
assert "BAD_break" in pipeline_fixture.raw.annotations.description
assert pipeline_fixture._repr_html_()
assert pipeline_fixture.flags["ch"].__repr__()


@pytest.mark.parametrize("logging", [True, False])
Expand Down
2 changes: 2 additions & 0 deletions pylossless/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._utils import _icalabel_to_data_frame
from .html import _get_ics, _sum_flagged_times, _create_html_details
File renamed without changes.
32 changes: 32 additions & 0 deletions pylossless/utils/html.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np


def _get_ics(df, ic_type):
if not df.empty:
return df[df["ic_type"] == ic_type]["component"].tolist()
return None

Check warning on line 7 in pylossless/utils/html.py

View check run for this annotation

Codecov / codecov/patch

pylossless/utils/html.py#L7

Added line #L7 was not covered by tests


def _sum_flagged_times(raw, flags):
"""Sum the total time flagged for various flags like noisy etc."""
flag_dict = {}
for flag in flags:
flag_dict[flag] = []
if raw:
inds = np.where(raw.annotations.description == flag)[0]
if len(inds):
flag_dict[flag] = np.sum(raw.annotations.duration[inds])
return flag_dict


def _create_html_details(title, data, times=False):
html_details = f"<details><summary><strong>{title}</strong></summary>"
html_details += "<table>"
for key, value in data.items():
if times: # special format for flagged times
value = f"{value:.2f} s" if value else value
html_details += f"<tr><td>{key}</td><td>{value} seconds</td></tr>"
else: # Channels, ICs
html_details += f"<tr><td>{key}</td><td>{value}</td></tr>"
html_details += "</table></details>"
return html_details
Loading