Skip to content

Commit

Permalink
Merge pull request #108 from CAMBI-tech/1.4.3-querying_stopping
Browse files Browse the repository at this point in the history
1.4.3 querying stopping
  • Loading branch information
azizkocana authored Dec 15, 2020
2 parents 33bd476 + ee7f74d commit 5d8542a
Show file tree
Hide file tree
Showing 13 changed files with 1,028 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from bcipy.signal.model.inference import inference
from bcipy.signal.process.filter import bandpass, notch, downsample
from bcipy.tasks.rsvp.main_frame import EvidenceFusion, DecisionMaker
from bcipy.tasks.rsvp.query_mechanisms import NBestStimuliAgent
from bcipy.tasks.rsvp.stopping_criteria import CriteriaEvaluator, \
MinIterationsCriteria, MaxIterationsCriteria, ProbThresholdCriteria
from bcipy.helpers.language_model import norm_domain, sym_appended, \
equally_probable, histogram

Expand Down Expand Up @@ -57,15 +60,26 @@ def __init__(self, min_num_seq, max_num_seq, signal_model=None, fs=300, k=2,
seq_constants = []
if backspace_always_shown and BACKSPACE_CHAR in alp:
seq_constants.append(BACKSPACE_CHAR)

# Stimuli Selection Module
stopping_criteria = CriteriaEvaluator(
continue_criteria=[MinIterationsCriteria(min_num_seq)],
commit_criteria=[MaxIterationsCriteria(max_num_seq),
ProbThresholdCriteria(decision_threshold)])

# TODO: Parametrize len_query in the future releases!
stimuli_agent = NBestStimuliAgent(alphabet=alp,
len_query=10)

self.decision_maker = DecisionMaker(
min_num_seq,
max_num_seq,
decision_threshold=decision_threshold,
stimuli_agent=stimuli_agent,
stopping_evaluator=stopping_criteria,
state=task_list[0][1],
alphabet=alp,
is_txt_stim=is_txt_stim,
stimuli_timing=stimuli_timing,
seq_constants=seq_constants)

self.alp = alp
# non-letter target labels include the fixation cross and calibration.
self.nonletters = ['+', 'PLUS', 'calibration_trigger']
Expand Down Expand Up @@ -100,7 +114,8 @@ def evaluate_sequence(self, raw_data, triggers, target_info, window_length):

# Remove 60hz noise with a notch filter
notch_filter_data = notch.notch_filter(
raw_data, self.sampling_rate, frequency_to_remove=self.notch_filter_frequency)
raw_data, self.sampling_rate,
frequency_to_remove=self.notch_filter_frequency)

# bandpass filter from 2-45hz
filtered_data = bandpass.butter_bandpass_filter(
Expand All @@ -113,7 +128,8 @@ def evaluate_sequence(self, raw_data, triggers, target_info, window_length):
# downsample
data = downsample.downsample(
filtered_data, factor=self.downsample_rate)
x, _, _, _ = trial_reshaper(target_info, times, data, fs=self.sampling_rate,
x, _, _, _ = trial_reshaper(target_info, times, data,
fs=self.sampling_rate,
k=self.downsample_rate, mode=self.mode,
channel_map=self.channel_map,
trial_length=window_length)
Expand Down Expand Up @@ -201,7 +217,8 @@ def initialize_epoch(self):
if alp_letter == prior_sym]

# display histogram of LM probabilities
print(f"Printed letters: '{self.decision_maker.displayed_state}'")
print(
f"Printed letters: '{self.decision_maker.displayed_state}'")
print(histogram(lm_letter_prior))

# Try fusing the lmodel evidence
Expand Down
2 changes: 1 addition & 1 deletion bcipy/helpers/demo/demo_eeg_model_related.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from bcipy.helpers.signal_model import CopyPhraseWrapper
from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper
from bcipy.signal.model.mach_learning.train_model import train_pca_rda_kde_model
from bcipy.helpers.task import alphabet

Expand Down
71 changes: 63 additions & 8 deletions bcipy/helpers/stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,58 @@
logging.getLogger("PIL").setLevel(logging.WARNING)


# TODO: since we have a querying this should replace the other generators
def rsvp_seq_generator(query: list,
timing=[1, 0.2],
color=['red', 'white'],
stim_number=1,
is_txt=True,
seq_constants=None) -> tuple:
""" Given the query set, prepares the stimuli, color and timing
Args:
query(list[str]): list of queries to be shown
timing(list[float]): Task specific timing for generator
color(list[str]): Task specific color for generator
First element is the target, second element is the fixation
Observe that [-1] element represents the trial information
seq_constants(list[str]): list of letters that should always be
included in every sequence. If provided, must be alp items.
Return:
schedule_seq(tuple(
samples[list[list[str]]]: list of sequences
timing(list[list[float]]): list of timings
color(list(list[str])): list of colors)): scheduled sequences
"""

# shuffle the returned values
random.shuffle(query)

stim_length = len(query)

# Init some lists to construct our stimuli with
samples, times, colors = [], [], []
for idx_num in range(stim_number):

# append a fixation cross. if not text, append path to image fixation
if is_txt:
sample = ['+']
else:
sample = ['bcipy/static/images/bci_main_images/PLUS.png']

# construct the sample from the query
sample += [i for i in query]
samples.append(sample)

# append timing
times.append([timing[i] for i in range(len(timing) - 1)] +
[timing[-1]] * stim_length)

# append colors
colors.append([color[i] for i in range(len(color) - 1)] +
[color[-1]] * stim_length)
return (samples, times, colors)


def best_selection(selection_elements: list,
val: list,
len_query: int,
Expand All @@ -41,9 +93,9 @@ def best_selection(selection_elements: list,
best = sorted(selection_elements, key=elem_val.get, reverse=True)[0:n]

replacements = [
item for item in always_included
if item not in best and item in selection_elements
][0:n]
item for item in always_included
if item not in best and item in selection_elements
][0:n]

if replacements:
best[-len(replacements):] = replacements
Expand Down Expand Up @@ -83,7 +135,7 @@ def best_case_rsvp_seq_gen(alp: list,
if len(alp) != len(session_stimuli):
raise Exception('Missing information about alphabet. len(alp):{}, '
'len(session_stimuli):{}, should be same!'.format(
len(alp), len(session_stimuli)))
len(alp), len(session_stimuli)))

if seq_constants and not set(seq_constants).issubset(alp):
raise Exception('Sequence constants must be alphabet items.')
Expand Down Expand Up @@ -172,7 +224,8 @@ def random_rsvp_calibration_seq_gen(alp, timing=[0.5, 1, 0.2],
return schedule_seq


def target_rsvp_sequence_generator(alp, target_letter, parameters, timing=[0.5, 1, 0.2],
def target_rsvp_sequence_generator(alp, target_letter, parameters,
timing=[0.5, 1, 0.2],
color=['green', 'white', 'white'],
stim_length=10, is_txt=True):
"""Target RSVP Sequence Generator.
Expand Down Expand Up @@ -204,7 +257,7 @@ def target_rsvp_sequence_generator(alp, target_letter, parameters, timing=[0.5,
else:
sample = ['bcipy/static/images/bci_main_images/PLUS.png']
target_letter = parameters['path_to_presentation_images'] + \
target_letter + '.png'
target_letter + '.png'
sample += [alp[i] for i in rand_smp]

# if the target isn't in the array, replace it with some random index that
Expand Down Expand Up @@ -357,7 +410,8 @@ def generate_icon_match_images(
'bcipy/static/images/bci_main_images/PLUS.png')

# Add target image to sequence, if it is not already there
if not target_image_numbers[sequence] in random_number_array[2:experiment_length]:
if not target_image_numbers[sequence] in random_number_array[
2:experiment_length]:
random_number_array[np.random.randint(
2, experiment_length)] = target_image_numbers[sequence]

Expand Down Expand Up @@ -393,7 +447,8 @@ def resize_image(image_path: str, screen_size: tuple, sti_height: int):
proportions[0], sti_height * proportions[1])
else:
sti_size = (
sti_height * proportions[0], (screen_width / screen_height) * sti_height * proportions[1])
sti_height * proportions[0],
(screen_width / screen_height) * sti_height * proportions[1])

return sti_size

Expand Down
2 changes: 1 addition & 1 deletion bcipy/helpers/tests/test_signal_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from bcipy.helpers.signal_model import CopyPhraseWrapper
from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper
from bcipy.helpers.task import alphabet


Expand Down
2 changes: 1 addition & 1 deletion bcipy/tasks/rsvp/copy_phrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from bcipy.feedback.visual.visual_feedback import VisualFeedback
from bcipy.helpers.triggers import _write_triggers_from_sequence_copy_phrase
from bcipy.helpers.save import _save_session_related_data
from bcipy.helpers.signal_model import CopyPhraseWrapper
from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper
from bcipy.helpers.task import (
fake_copy_phrase_decision, alphabet, process_data_for_decision,
trial_complete_message, get_user_input)
Expand Down
2 changes: 1 addition & 1 deletion bcipy/tasks/rsvp/icon_to_icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bcipy.display.rsvp.mode.icon_to_icon import IconToIconDisplay
from bcipy.feedback.visual.visual_feedback import FeedbackType, VisualFeedback
from bcipy.helpers.save import _save_session_related_data
from bcipy.helpers.signal_model import CopyPhraseWrapper
from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper
from bcipy.helpers.task import (alphabet, generate_targets, get_user_input,
process_data_for_decision,
trial_complete_message)
Expand Down
Loading

0 comments on commit 5d8542a

Please sign in to comment.