Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge adjacent epoch flags into a single annotation #151

Merged
merged 3 commits into from
Nov 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 21 additions & 83 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,72 +346,6 @@ def chan_neighbour_r(epochs, nneigbr, method):
return m_neigbr_r.rename(ref_chan="ch")


# TODO: check that annot type contains all unique flags
def marks_flag_gap(
raw, min_gap_ms, included_annot_type=None, out_annot_name="bad_pylossless_gap"
):
"""Mark small gaps in time between pylossless annotations.

Parameters
----------
raw : mne.Raw
An instance of mne.Raw
min_gap_ms : int
Time in milleseconds. If the time between two consecutive pylossless
annotations is less than this value, that time period will be
annotated.
included_annot_type : str (Default None)
Descriptions of the `mne.Annotations` in the `mne.Raw` to be included.
If `None`, includes ('bad_pylossless_ch_sd', 'bad_pylossless_low_r',
'bad_pylossless_ic_sd1', 'bad_pylossless_gap').
out_annot_name : str (default 'bad_pylossless_gap')
The description for the `mne.Annotation` That is created for any gaps.

Returns
-------
Annotations : `mne.Annotations`
An instance of `mne.Annotations`
"""
if included_annot_type is None:
included_annot_type = (
"bad_pylossless_ch_sd",
"bad_pylossless_low_r",
"bad_pylossless_ic_sd1",
"bad_pylossless_gap",
)

if len(raw.annotations) == 0:
return mne.Annotations([], [], [], orig_time=raw.annotations.orig_time)

ret_val = np.array(
[
[annot["onset"], annot["duration"]]
for annot in raw.annotations
if annot["description"] in included_annot_type
]
).T

if len(ret_val) == 0:
return mne.Annotations([], [], [], orig_time=raw.annotations.orig_time)

onsets, durations = ret_val
offsets = onsets + durations
gaps = np.array(
[
min(onset - offsets[offsets < onset]) if np.sum(offsets < onset) else np.inf
for onset in onsets[1:]
]
)
gap_mask = gaps < min_gap_ms / 1000

return mne.Annotations(
onset=onsets[1:][gap_mask] - gaps[gap_mask],
duration=gaps[gap_mask],
description=out_annot_name,
orig_time=raw.annotations.orig_time,
)


def coregister(
raw_edf,
fiducials="estimated", # get fiducials from fsaverage
Expand Down Expand Up @@ -645,14 +579,29 @@ def add_pylossless_annotations(self, inds, event_type, epochs):
"""
# Concatenate epoched data back to continuous data
t_onset = epochs.events[inds, 0] / epochs.info["sfreq"]
df = pd.DataFrame(t_onset, columns=["onset"])
# We exclude the last sample from the duration because
# if the annot lasts the whole duration of the epoch
# it's end will coincide with the first sample of the
# next epoch, causing it to erroneously be rejected.
duration = np.ones_like(t_onset) / epochs.info["sfreq"] * len(epochs.times[:-1])
description = [f"bad_pylossless_{event_type}"] * len(t_onset)
df["duration"] = 1 / epochs.info["sfreq"] * len(epochs.times[:-1])
df["description"] = f"bad_pylossless_{event_type}"

# Merge close onsets to prevent a bunch of 1-second annotations of the same name
# find onsets close enough to be considered the same
df["close"] = df.sort_values("onset")["onset"].diff().le(1)
df["group"] = ~df["close"]
df["group"] = df["group"].cumsum()
# group the close onsets and merge them
df["onset"] = df.groupby("group")["onset"].transform("first")
df["duration"] = df.groupby("group")["duration"].transform("sum")
df = df.drop_duplicates(subset=["onset", "duration"])

annotations = mne.Annotations(
t_onset, duration, description, orig_time=self.raw.annotations.orig_time
df["onset"],
df["duration"],
df["description"],
orig_time=self.raw.annotations.orig_time,
)
self.raw.set_annotations(self.raw.annotations + annotations)

Expand Down Expand Up @@ -1025,11 +974,6 @@ def flag_epoch_low_r(self):
logger.info(f"📋 LOSSLESS: Uncorrelated epochs: {bad_epoch_inds}")
self.flags["epoch"].add_flag_cat("low_r", bad_epoch_inds, epochs)

def flag_epoch_gap(self):
"""Flag small time periods between pylossless annotations."""
annots = marks_flag_gap(self.raw, self.config["epoch_gap"]["min_gap_ms"])
self.raw.set_annotations(self.raw.annotations + annots)

@lossless_logger
def run_ica(self, run):
"""Run ICA.
Expand Down Expand Up @@ -1218,21 +1162,15 @@ def _run(self):
# 9. Calculate nearest neighbour R values for epochs
self.flag_epoch_low_r(message="Flagging Uncorrelated epochs")

# 10. Flag very small time periods between flagged time
self.flag_epoch_gap()

# 11. Run ICA
# 10. Run ICA
self.run_ica("run1", message="Running Initial ICA")

# 12. Calculate IC SD
# 11. Calculate IC SD
self.flag_epoch_ic_sd1(message="Flagging time periods with noisy" " IC's.")

# 13. TODO: integrate labels from IClabels to self.flags["ic"]
# 12. TODO: integrate labels from IClabels to self.flags["ic"]
self.run_ica("run2", message="Running Final ICA.")

# 14. Flag very small time periods between flagged time
self.flag_epoch_gap()

def run_dataset(self, paths):
"""Run a full dataset.

Expand Down
Loading