Skip to content

Commit

Permalink
Merge pull request #332 from CAMBI-tech/task_registry_refactor
Browse files Browse the repository at this point in the history
Task registry refactor
  • Loading branch information
Carsonthemonkey authored Jul 18, 2024
2 parents 6214dc0 + 0cbb8e1 commit 0da74ab
Show file tree
Hide file tree
Showing 23 changed files with 174 additions and 239 deletions.
1 change: 0 additions & 1 deletion bcipy/gui/experiments/ExperimentField.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ def start_app() -> None:
"""Start Experiment Field Collection."""
import argparse
from bcipy.config import DEFAULT_EXPERIMENT_ID, EXPERIMENT_DATA_FILENAME
from bcipy.helpers.validate import validate_experiment, validate_field_data_written

parser = argparse.ArgumentParser()

Expand Down
24 changes: 13 additions & 11 deletions bcipy/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
import multiprocessing
from typing import List, Optional
from typing import List, Optional, Type

from psychopy import visual

Expand All @@ -20,16 +20,17 @@
from bcipy.helpers.task import print_message
from bcipy.helpers.validate import validate_bcipy_session, validate_experiment
from bcipy.helpers.visualization import visualize_session_data
from bcipy.task import TaskType
from bcipy.task import TaskRegistry, Task
from bcipy.task.start_task import start_task

log = logging.getLogger(__name__)
task_registry = TaskRegistry()


def bci_main(
parameter_location: str,
user: str,
task: TaskType,
task: Type[Task],
experiment: str = DEFAULT_EXPERIMENT_ID,
alert: bool = False,
visualize: bool = True,
Expand All @@ -49,7 +50,7 @@ def bci_main(
Input:
parameter_location (str): location of parameters file to use
user (str): name of the user
task (TaskType): registered bcipy TaskType
task (Task): registered bcipy Task
experiment_id (str): Name of the experiment. Default name is DEFAULT_EXPERIMENT_ID.
alert (bool): whether to alert the user when the task is complete
visualize (bool): whether to visualize data at the end of a task
Expand Down Expand Up @@ -84,7 +85,7 @@ def bci_main(
parameters['data_save_loc'],
user,
parameter_location,
task=task.label,
task=task.name,
experiment_id=experiment)

# configure bcipy session logging
Expand All @@ -110,7 +111,7 @@ def bci_main(


def execute_task(
task: TaskType,
task: Type[Task],
parameters: dict,
save_folder: str,
alert: bool,
Expand All @@ -122,7 +123,7 @@ def execute_task(
which will initialize experiment.
Input:
task(TaskType): Task that should be registered in TaskType
task(Task): Task that should be registered in TaskRegistry
parameters (dict): parameter dictionary
save_folder (str): path to save folder
alert (bool): whether to alert the user when the task is complete
Expand All @@ -136,7 +137,7 @@ def execute_task(

# Init EEG Model, if needed. Calibration Tasks Don't require probabilistic
# modules to be loaded.
if task not in TaskType.calibration_tasks():
if task not in task_registry.calibration_tasks():
# Try loading in our signal_model and starting a langmodel(if enabled)
if not fake:
try:
Expand Down Expand Up @@ -222,9 +223,9 @@ def bcipy_main() -> None:
"""
# Needed for windows machines
multiprocessing.freeze_support()

tr = TaskRegistry()
experiment_options = list(load_experiments().keys())
task_options = TaskType.list()
task_options = tr.list()
parser = argparse.ArgumentParser()

# Command line utility for adding arguments/ paths via command line
Expand Down Expand Up @@ -259,7 +260,8 @@ def bcipy_main() -> None:
args = parser.parse_args()

# Start BCI Main
bci_main(args.parameters, str(args.user), TaskType.by_value(str(args.task)),
task = task_registry.get(args.task)
bci_main(args.parameters, str(args.user), task,
str(args.experiment), args.alert, args.noviz, args.fake)


Expand Down
19 changes: 7 additions & 12 deletions bcipy/orchestrator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@

from typing import List, Type
from bcipy.task import Task
from bcipy.orchestrator.actions import task_registry_dict
from bcipy.config import TASK_SEPERATOR

# This is a temporary solution and will be improved in the refactored `TaskRegistry` class.
task_name_dict = {v: k for k, v in task_registry_dict.items()}
from bcipy.task.task_registry import TaskRegistry


def parse_sequence(sequence: str) -> List[Type[Task]]:
Expand All @@ -26,11 +23,8 @@ def parse_sequence(sequence: str) -> List[Type[Task]]:
List[TaskType]
A list of TaskType objects that represent the actions in the input string.
"""
try:
task_sequence = [task_registry_dict[task.strip()] for task in sequence.split(TASK_SEPERATOR)]
except KeyError as e:
raise ValueError('Invalid task name in action sequence') from e
return task_sequence
task_registry = TaskRegistry()
return [task_registry.get(item.strip()) for item in sequence.split(TASK_SEPERATOR)]


def validate_sequence_string(action_sequence: str) -> None:
Expand All @@ -50,8 +44,8 @@ def validate_sequence_string(action_sequence: str) -> None:
If the string of actions is invalid.
"""
for sequence_item in action_sequence.split(TASK_SEPERATOR):
if sequence_item.strip() not in task_registry_dict:
raise ValueError('Invalid task name in action sequence')
if sequence_item.strip() not in TaskRegistry().list():
raise ValueError(f"Invalid task '{sequence_item}' name in action sequence")


def serialize_sequence(sequence: List[Type[Task]]) -> str:
Expand All @@ -71,7 +65,8 @@ def serialize_sequence(sequence: List[Type[Task]]) -> str:
List[TaskType]
A list of TaskType objects that represent the actions in the input string.
"""
return f" {TASK_SEPERATOR} ".join([task_name_dict[item] for item in sequence])

return f" {TASK_SEPERATOR} ".join([item.name for item in sequence])


if __name__ == '__main__':
Expand Down
9 changes: 4 additions & 5 deletions bcipy/orchestrator/demo/demo_orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from bcipy.orchestrator.orchestrator import SessionOrchestrator
from bcipy.orchestrator.actions import OfflineAnalysisAction, ExperimentFieldCollectionAction
from bcipy.config import DEFAULT_EXPERIMENT_ID
from bcipy.config import DEFAULT_PARAMETER_FILENAME

from bcipy.config import DEFAULT_EXPERIMENT_ID, DEFAULT_PARAMETER_FILENAME
from bcipy.helpers.load import load_experimental_data
from bcipy.orchestrator.orchestrator import SessionOrchestrator
from bcipy.task.actions import (ExperimentFieldCollectionAction,
OfflineAnalysisAction)


def demo_orchestrator(data_path: str, parameters_path: str) -> None:
Expand Down
5 changes: 3 additions & 2 deletions bcipy/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class SessionOrchestrator:
sys_info: dict
log: Logger
save_folder: Optional[str] = None
session_data: List[str] # This may need to be a list of dictionaries or objects here in the future
# This will need to be refactored to a more complex data structure to store data from each task
session_data: List[str]
# Session Orchestrator will contain global objects here (DAQ, models etc) to be shared between executed tasks.

def __init__(
Expand Down Expand Up @@ -80,7 +81,7 @@ def execute(self) -> None:
self.save()

def init_orchestrator_save_folder(self, save_path: str) -> None:
timestamp = str(datetime.now())
timestamp = datetime.now().strftime("%Y-%m-%d %H-%M")
# * No '/' after `save_folder` since it is included in
# * `data_save_location` in parameters
path = f'{save_path}{self.experiment_id}/{self.user}/orchestrator-run-{timestamp}/'
Expand Down
55 changes: 27 additions & 28 deletions bcipy/orchestrator/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from bcipy.orchestrator.config import parse_sequence, serialize_sequence, validate_sequence_string
from bcipy.task import TaskType
from bcipy.orchestrator.actions import OfflineAnalysisAction
from bcipy.task.actions import OfflineAnalysisAction
from bcipy.task.paradigm.rsvp.calibration.calibration import RSVPCalibrationTask
from bcipy.task.paradigm.rsvp.copy_phrase import RSVPCopyPhraseTask


class TestTaskProtocolProcessing(unittest.TestCase):
Expand All @@ -10,10 +11,10 @@ def test_parses_one_task(self) -> None:
sequence = 'RSVP Calibration'
parsed = parse_sequence(sequence)
assert len(parsed) == 1
assert parsed[0] == TaskType.RSVP_CALIBRATION
assert parsed[0] is RSVPCalibrationTask

def test_parses_one_action(self) -> None:
actions = 'Offline Analysis'
def test_parses_with_task_name(self) -> None:
actions = OfflineAnalysisAction.name
parsed = parse_sequence(actions)
assert len(parsed) == 1
assert parsed[0] is OfflineAnalysisAction
Expand All @@ -22,24 +23,32 @@ def test_parses_multiple_tasks(self) -> None:
actions = 'RSVP Calibration -> RSVP Copy Phrase'
parsed = parse_sequence(actions)
assert len(parsed) == 2
assert parsed[0] == TaskType.RSVP_CALIBRATION
assert parsed[1] == TaskType.RSVP_COPY_PHRASE
assert parsed[0] is RSVPCalibrationTask
assert parsed[1] is RSVPCopyPhraseTask

def test_parses_actions_and_tasks(self) -> None:
sequence = 'RSVP Calibration -> Offline Analysis -> RSVP Copy Phrase'
sequence = 'RSVP Calibration -> Offline Analysis Action -> RSVP Copy Phrase'
parsed = parse_sequence(sequence)
assert len(parsed) == 3
assert parsed[0] == TaskType.RSVP_CALIBRATION
assert parsed[0] is RSVPCalibrationTask
assert parsed[1] is OfflineAnalysisAction
assert parsed[2] == TaskType.RSVP_COPY_PHRASE
assert parsed[2] is RSVPCopyPhraseTask

def test_throws_exception_on_invalid_action(self) -> None:
def test_parses_sequence_with_extra_spaces(self) -> None:
actions = ' RSVP Calibration -> Offline Analysis Action -> RSVP Copy Phrase '
parsed = parse_sequence(actions)
assert len(parsed) == 3
assert parsed[0] is RSVPCalibrationTask
assert parsed[1] is OfflineAnalysisAction
assert parsed[2] is RSVPCopyPhraseTask

def test_throws_exception_on_invalid_task(self) -> None:
actions = 'RSVP Calibration -> does not exist'
with self.assertRaises(ValueError):
parse_sequence(actions)

def test_throws_exception_on_invalid_task(self) -> None:
actions = 'RSVP Calibration -> RSVP Copy Phrase -> does not exist'
def test_throws_exception_on_invalid_string(self) -> None:
actions = 'thisstringisbad'
with self.assertRaises(ValueError):
parse_sequence(actions)

Expand All @@ -53,21 +62,11 @@ def test_throws_exception_on_invalid_action_string(self) -> None:
validate_sequence_string(actions)

def test_serializes_one_task(self) -> None:
actions = [TaskType.RSVP_CALIBRATION]
actions = [RSVPCalibrationTask]
serialized = serialize_sequence(actions)
assert serialized == 'RSVP Calibration'

def test_serializes_one_action(self) -> None:
sequence = [OfflineAnalysisAction]
serialized = serialize_sequence(sequence)
assert serialized == 'Offline Analysis'

def test_serializes_actions_and_tasks(self) -> None:
sequence = [TaskType.RSVP_CALIBRATION, OfflineAnalysisAction, TaskType.RSVP_COPY_PHRASE]
serialized = serialize_sequence(sequence)
assert serialized == 'RSVP Calibration -> Offline Analysis -> RSVP Copy Phrase'
assert serialized == RSVPCalibrationTask.name

def test_serializes_multiple_tasks(self) -> None:
actions = [TaskType.RSVP_CALIBRATION, TaskType.RSVP_COPY_PHRASE]
serialized = serialize_sequence(actions)
assert serialized == 'RSVP Calibration -> RSVP Copy Phrase'
sequence = [RSVPCalibrationTask, OfflineAnalysisAction, RSVPCopyPhraseTask]
serialized = serialize_sequence(sequence)
assert serialized == 'RSVP Calibration -> Offline Analysis Action -> RSVP Copy Phrase'
6 changes: 4 additions & 2 deletions bcipy/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
This import statement allows users to import submodules from Task
"""
from .main import Task
from .task_registry import TaskType

# Makes the following classes available to the task registry
from .task_registry import TaskRegistry

__all__ = [
'Task',
'TaskType',
'TaskRegistry'
]
Loading

0 comments on commit 0da74ab

Please sign in to comment.