Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backspace always shown #178

Merged
merged 2 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions bcipy/helpers/stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def alphabetize(stimuli: List[str]) -> List[str]:
def rsvp_inq_generator(query: List[str],
timing: List[float] = [1, 0.2],
color: List[str] = ['red', 'white'],
stim_number: int = 1,
inquiry_count: int = 1,
stim_order: StimuliOrder = StimuliOrder.RANDOM,
is_txt: bool = True) -> InquirySchedule:
""" Given the query set, prepares the stimuli, color and timing
Expand All @@ -89,7 +89,7 @@ def rsvp_inq_generator(query: List[str],

# Init some lists to construct our stimuli with
samples, times, colors = [], [], []
for _ in range(stim_number):
for _ in range(inquiry_count):

# append a fixation cross. if not text, append path to image fixation
sample = [get_fixation(is_txt)]
Expand Down Expand Up @@ -127,15 +127,14 @@ def best_selection(selection_elements: list,
best_selection(list[str]): elements from selection_elements with the best values """

always_included = always_included or []
n = len_query
# pick the top n items sorted by value in decreasing order
elem_val = dict(zip(selection_elements, val))
best = sorted(selection_elements, key=elem_val.get, reverse=True)[0:n]
best = sorted(selection_elements, key=elem_val.get, reverse=True)[0:len_query]

replacements = [
item for item in always_included
if item not in best and item in selection_elements
][0:n]
][0:len_query]

if replacements:
best[-len(replacements):] = replacements
Expand Down Expand Up @@ -182,12 +181,9 @@ def best_case_rsvp_inq_gen(alp: list,
if inq_constants and not set(inq_constants).issubset(alp):
raise BciPyCoreException('Inquiry constants must be alphabet items.')

# create a list of alphabet letters
alphabet = [i for i in alp]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes in our codebase we do this to copy a list since it will be mutated by the method it's passed into, but it doesn't look like best_selection mutates its arguments. Maybe in a previous iteration? Anyway, nice cleanup.

# query for the best selection
query = best_selection(
alphabet,
alp,
session_stimuli,
stim_length,
inq_constants)
Expand Down
8 changes: 0 additions & 8 deletions bcipy/helpers/tests/test_copy_phrase_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,6 @@ def test_init_series_evaluate_inquiry(self):

is_accepted, sti = copy_phrase_task.initialize_series()
self.assertFalse(is_accepted)
self.assertEqual(
sti,
(
[["+", "U", "T", "_", "W", "Y", "X", "Z", "<", "S", "V"]],
[[self.params["time_cross"]] + [self.params["time_flash"]] * self.params["stim_length"]],
[[self.params["fixation_color"]] + [self.params["stim_color"]] * self.params["stim_length"]],
),
)

triggers = [
("+", 0.0),
Expand Down
7 changes: 4 additions & 3 deletions bcipy/task/control/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(self,

self.last_selection = ''

# Items shown in every inquiry TODO this is unused
# Items shown in every inquiry
self.inq_constants = inq_constants

def reset(self, state=''):
Expand Down Expand Up @@ -278,10 +278,11 @@ def prepare_stimuli(self) -> InquirySchedule:

# querying agent decides on possible letters to be shown on the screen
query_els = self.stimuli_agent.return_stimuli(
self.list_series[-1]['list_distribution'])
self.list_series[-1]['list_distribution'],
constants=self.inq_constants)
# once querying is determined, append with timing and color info.
stimuli = rsvp_inq_generator(query=query_els,
stim_number=1,
inquiry_count=1,
is_txt=self.is_txt_stim,
timing=self.stimuli_timing,
stim_order=self.stimuli_order)
Expand Down
20 changes: 9 additions & 11 deletions bcipy/task/control/query.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import random
from typing import List, Any
from typing import List, Optional
from abc import ABC, abstractmethod

from bcipy.helpers.stimuli import best_selection


class StimuliAgent(ABC):
@abstractmethod
Expand Down Expand Up @@ -45,11 +47,14 @@ def reset(self):
""" This querying method is memoryless no reset needed """
pass

def return_stimuli(self, list_distribution: np.ndarray):
def return_stimuli(self, list_distribution: np.ndarray, constants: Optional[List[str]] = None):
""" return random elements from the alphabet """
tmp = [i for i in self.alphabet]
query = random.sample(tmp, self.len_query)

if constants:
query[-len(constants):] = constants

return query

def do_series(self):
Expand All @@ -75,19 +80,12 @@ def __init__(self, alphabet: List[str], len_query: int = 4):
def reset(self):
pass

def return_stimuli(self, list_distribution: np.ndarray):
def return_stimuli(self, list_distribution: np.ndarray, constants: Optional[List[str]] = None):
p = list_distribution[-1]
tmp = [i for i in self.alphabet]
query = best_selection(tmp, p, self.len_query)
query = best_selection(tmp, p, self.len_query, always_included=constants)

return query

def do_series(self):
pass


def best_selection(list_el: List[Any], val: List[float], len_query: int):
"""Return the top `len_query` items from `list_el` according to the values in `val`"""
# numpy version: return list_el[(-val).argsort()][:len_query]
sorted_items = reversed(sorted(zip(val, list_el)))
return [el for (value, el) in sorted_items][:len_query]
2 changes: 1 addition & 1 deletion bcipy/task/tests/core/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_best_selection(self):
list_el = ["A", "E", "I", "O", "U"]
values = [0.1, 0.2, 0.2, 0.2, 0.2]
len_query = 3
self.assertEqual(["U", "O", "I"], best_selection(list_el, values, len_query))
self.assertEqual(["E", "I", "O"], best_selection(list_el, values, len_query))


if __name__ == '__main__':
Expand Down