Skip to content

Commit

Permalink
Prestimulus, Inquiry Based Training, Model Tuning (#208)
Browse files Browse the repository at this point in the history
Co-authored-by: Niklas <[email protected]>
Co-authored-by: lawhead <[email protected]>
  • Loading branch information
3 people authored Apr 14, 2022
1 parent 31c9924 commit 6b672a4
Show file tree
Hide file tree
Showing 42 changed files with 1,247 additions and 1,554 deletions.
2 changes: 1 addition & 1 deletion bcipy/acquisition/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bcipy.acquisition.connection_method import ConnectionMethod
from bcipy.helpers.system_utils import auto_str, DEFAULT_ENCODING
IRREGULAR_RATE = 0.0
DEFAULT_CONFIG = 'bcipy/acquisition/devices.json'
DEFAULT_CONFIG = Path(__file__).resolve().parent / 'devices.json'
_SUPPORTED_DEVICES = {}
# see https://labstreaminglayer.readthedocs.io/projects/liblsl/ref/enums.html
SUPPORTED_DATA_TYPES = [
Expand Down
4 changes: 4 additions & 0 deletions bcipy/acquisition/tests/datastream/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ def count_generator(low=0, high=10, step=1):
gen4 = new_generator(step=2)
self.assertEqual(1, next(gen4))
self.assertEqual(3, next(gen4))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion bcipy/gui/BCInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def offline_analysis(self) -> None:
Run offline analysis as a script in a new process.
"""
cmd = 'python bcipy/signal/model/offline_analysis.py'
cmd = 'python bcipy/signal/model/offline_analysis.py --alert'
subprocess.Popen(cmd, shell=True)


Expand Down
2 changes: 1 addition & 1 deletion bcipy/helpers/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def max_inquiry_duration(parameters: dict) -> float:
target_duration = parameters['time_prompt']
stim_count = parameters['stim_length']
stim_duration = parameters['time_flash']
interval_duration = parameters['task_buffer_len']
interval_duration = parameters['task_buffer_length']

return target_duration + fixation_duration + (
stim_count * stim_duration) + interval_duration
Expand Down
1 change: 1 addition & 0 deletions bcipy/helpers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, List, Tuple

import numpy as np

from pyedflib import FILETYPE_EDFPLUS, EdfWriter
from tqdm import tqdm

Expand Down
15 changes: 8 additions & 7 deletions bcipy/helpers/copy_phrase_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
histogram,
with_min_prob,
)
from bcipy.helpers.stimuli import InquirySchedule, StimuliOrder
from bcipy.helpers.stimuli import InquirySchedule, StimuliOrder, TrialReshaper
from bcipy.helpers.task import BACKSPACE_CHAR
from bcipy.signal.model import SignalModel
from bcipy.signal.process import get_default_transform
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(self,

def evaluate_inquiry(
self, raw_data: np.array, triggers: List[Tuple[str, float]],
target_info: List[str], window_length: int
target_info: List[str], window_length: float
) -> Tuple[bool, Tuple[List[str], List[float], List[str]]]:
"""Once data is collected, infers meaning from the data and attempt to
make a decision.
Expand Down Expand Up @@ -172,7 +172,7 @@ def evaluate_inquiry(
def evaluate_eeg_evidence(self, raw_data: np.array,
triggers: List[Tuple[str, float]],
target_info: List[str],
window_length: int) -> np.array:
window_length: float) -> np.array:
"""Once data is collected, infers meaning from the data and return the results.
Parameters
Expand All @@ -198,16 +198,17 @@ def evaluate_eeg_evidence(self, raw_data: np.array,
bandpass_order=self.filter_order,
downsample_factor=self.downsample_rate,
)

data, transformed_sample_rate = default_transform(raw_data, self.sampling_rate)

data, _ = self.signal_model.reshaper(
trial_labels=target_info,
# The data from DAQ is assumed to have offsets applied
data, _ = TrialReshaper()(
trial_targetness_label=target_info,
timing_info=times,
eeg_data=data,
fs=transformed_sample_rate,
trials_per_inquiry=self.stim_length,
channel_map=self.channel_map,
trial_length=window_length)
poststimulus_length=window_length)

return self.signal_model.predict(data, letters, self.alp)

Expand Down
16 changes: 16 additions & 0 deletions bcipy/helpers/list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility functions for list processing."""
from typing import Callable, List
from itertools import zip_longest


def destutter(items: List, key: Callable = lambda x: x) -> List:
Expand All @@ -18,3 +19,18 @@ def destutter(items: List, key: Callable = lambda x: x) -> List:
else:
deduped[-1] = item
return deduped


def grouper(iterable, chunk_size, incomplete="fill", fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
chunks = [iter(iterable)] * chunk_size
if incomplete == "fill":
if fillvalue:
return zip_longest(*chunks, fillvalue=fillvalue)
raise ValueError('fillvalue must be defined if using incomplete=fill')
if incomplete == "ignore":
return zip(*chunks)

raise ValueError("Expected fill or ignore")
186 changes: 177 additions & 9 deletions bcipy/helpers/stimuli.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import glob
import itertools
import re
import logging
import random
import re

from abc import ABC, abstractmethod
from enum import Enum
from os import path, sep
from typing import Iterator, List, Tuple, NamedTuple

import numpy as np
from enum import Enum
from bcipy.helpers.exceptions import BciPyCoreException
from bcipy.helpers.list import grouper

import sounddevice as sd
import soundfile as sf
from PIL import Image

from bcipy.helpers.exceptions import BciPyCoreException
# Prevents pillow from filling the console with debug info
logging.getLogger('PIL').setLevel(logging.WARNING)

from psychopy import core
import numpy as np
import sounddevice as sd
import soundfile as sf


# Prevents pillow from filling the console with debug info
logging.getLogger('PIL').setLevel(logging.WARNING)
log = logging.getLogger(__name__)
DEFAULT_FIXATION_PATH = 'bcipy/static/images/main/PLUS.png'

Expand Down Expand Up @@ -66,6 +69,171 @@ class InquirySchedule(NamedTuple):
colors: List[List[str]]


class Reshaper(ABC):

@abstractmethod
def __call__(self):
...


class InquiryReshaper:
def __call__(self,
trial_targetness_label: List[str],
timing_info: List[float],
eeg_data: np.ndarray,
fs: int,
trials_per_inquiry: int,
offset: float = 0,
channel_map: List[int] = None,
poststimulus_length: float = 0.5,
prestimulus_length: float = 0.0,
transformation_buffer: float = 0.0,
target_label: str = 'target') -> Tuple[np.ndarray, np.ndarray]:
"""Extract inquiry data and labels.
Args:
trial_targetness_label (List[str]): labels each trial as "target", "non-target", "first_pres_target", etc
timing_info (List[float]): Timestamp of each event in seconds
eeg_data (np.ndarray): shape (channels, samples) preprocessed EEG data
fs (int): sample rate of EEG data. If data is downsampled, the sample rate should be also be downsampled.
trials_per_inquiry (int): number of trials in each inquiry
offset (float, optional): Any calculated or hypothesized offsets in timings. Defaults to 0.
channel_map (List[int], optional): Describes which channels to include or discard.
Defaults to None; all channels will be used.
poststimulus_length (float, optional): time in seconds needed after the last trial in an inquiry.
prestimulus_length (float, optional): time in seconds needed before the first trial in an inquiry.
transformation_buffer (float, optional): time in seconds to buffer the end of the inquiry. Defaults to 0.0.
target_label (str): label of target symbol. Defaults to "target"
Returns:
reshaped_data (np.ndarray): inquiry data of shape (Channels, Inquiries, Samples)
labels (np.ndarray): integer label for each inquiry. With `trials_per_inquiry=K`,
a label of [0, K-1] indicates the position of `target_label`, or label of K indicates
`target_label` was not present.
reshaped_trigger_timing (List[List[int]]): For each inquiry, a list of the sample index where each trial
begins, accounting for the prestim buffer that may have been added to the front of each inquiry.
"""
if channel_map:
# Remove the channels that we are not interested in
channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
eeg_data = np.delete(eeg_data, channels_to_remove, axis=0)

n_inquiry = len(timing_info) // trials_per_inquiry
trial_duration_samples = int(poststimulus_length * fs)
prestimulus_samples = int(prestimulus_length * fs)

# triggers in seconds are mapped to triggers in number of samples.
triggers = list(map(lambda x: int((x + offset) * fs), timing_info))

# First, find the longest inquiry in this experiment
# We'll add or remove a few samples from all other inquiries, to match this length
def get_inquiry_len(inq_trigs):
return inq_trigs[-1] - inq_trigs[0]

longest_inquiry = max(grouper(triggers, trials_per_inquiry, fillvalue='x'), key=lambda xy: get_inquiry_len(xy))
num_samples_per_inq = get_inquiry_len(longest_inquiry) + trial_duration_samples
buffer_samples = int(transformation_buffer * fs)

# Label for every inquiry
labels = np.zeros(
(n_inquiry, trials_per_inquiry), dtype=np.compat.long
) # maybe this can be configurable? return either class indexes or labels ('nontarget' etc)
reshaped_data, reshaped_trigger_timing = [], []
for inquiry_idx, trials_within_inquiry in enumerate(
grouper(zip(trial_targetness_label, triggers), trials_per_inquiry, fillvalue='x')
):
first_trigger = trials_within_inquiry[0][1]

trial_triggers = []
for trial_idx, (trial_label, trigger) in enumerate(trials_within_inquiry):
if trial_label == target_label:
labels[inquiry_idx, trial_idx] = 1

# If presimulus buffer is used, we add it here so that trigger timings will
# still line up with trial onset
trial_triggers.append((trigger - first_trigger) + prestimulus_samples)
reshaped_trigger_timing.append(trial_triggers)
start = first_trigger - prestimulus_samples
stop = first_trigger + num_samples_per_inq + buffer_samples
reshaped_data.append(eeg_data[:, start:stop])

return np.stack(reshaped_data, 1), labels, reshaped_trigger_timing

@staticmethod
def extract_trials(inquiries, samples_per_trial, inquiry_timing, downsample_rate=1):
"""Extract Trials.
After using the InquiryReshaper, it may be necessary to futher trial the data for processing.
Using the number of samples and inquiry timing, the data is reshaped from Channels, Inquiry, Samples to
Channels, Trials, Samples. These should match with the trials extracted from the TrialReshaper given the same
slicing parameters.
"""
new_trials = []
num_inquiries = inquiries.shape[1]
for inquiry_idx, timing in zip(range(num_inquiries), inquiry_timing): # C x I x S

for time in timing:
time = time // downsample_rate
y = time + samples_per_trial
new_trials.append(inquiries[:, inquiry_idx, time:y])
return np.stack(new_trials, 1) # C x T x S


class TrialReshaper(Reshaper):
def __call__(self,
trial_targetness_label: list,
timing_info: list,
eeg_data: np.ndarray,
fs: int,
offset: float = 0,
channel_map: List[int] = None,
poststimulus_length: float = 0.5,
prestimulus_length: float = 0.0,
target_label: str = "target") -> Tuple[np.ndarray, np.ndarray]:
"""Extract trial data and labels.
Args:
trial_targetness_label (list): labels each trial as "target", "non-target", "first_pres_target", etc
timing_info (list): Timestamp of each event in seconds
eeg_data (np.ndarray): shape (channels, samples) preprocessed EEG data
fs (int): sample rate of preprocessed EEG data
trials_per_inquiry (int, optional): unused, kept here for consistent interface with `inquiry_reshaper`
offset (float, optional): Any calculated or hypothesized offsets in timings.
Defaults to 0.
channel_map (List, optional): Describes which channels to include or discard.
Defaults to None; all channels will be used.
poststimulus_length (float, optional): [description]. Defaults to 0.5.
target_label (str): label of target symbol. Defaults to "target"
Returns:
trial_data (np.ndarray): shape (channels, trials, samples) reshaped data
labels (np.ndarray): integer label for each trial
"""
# Remove the channels that we are not interested in
if channel_map:
channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
eeg_data = np.delete(eeg_data, channels_to_remove, axis=0)

# Number of samples we are interested per trial
poststim_samples = int(poststimulus_length * fs)
prestim_samples = int(prestimulus_length * fs)

# triggers in seconds are mapped to triggers in number of samples.
triggers = list(map(lambda x: int((x + offset) * fs), timing_info))

# Label for every trial in 0 or 1
targetness_labels = np.zeros(len(triggers), dtype=np.compat.long)
reshaped_trials = []
for trial_idx, (trial_label, trigger) in enumerate(zip(trial_targetness_label, triggers)):
if trial_label == target_label:
targetness_labels[trial_idx] = 1

# For every channel append filtered channel data to trials
reshaped_trials.append(eeg_data[:, trigger - prestim_samples: trigger + poststim_samples])

return np.stack(reshaped_trials, 1), targetness_labels


def alphabetize(stimuli: List[str]) -> List[str]:
"""Alphabetize.
Expand Down
Loading

0 comments on commit 6b672a4

Please sign in to comment.