From d78fbc3730cf42463938a4b2ae4144bfd60ff348 Mon Sep 17 00:00:00 2001 From: Scott Huberty <52462026+scott-huberty@users.noreply.github.com> Date: Sun, 30 Apr 2023 21:46:44 -0400 Subject: [PATCH 01/17] Add or update the Azure App Service build and deployment workflow config --- .github/workflows/main_pylossless-qc-demo.yml | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 .github/workflows/main_pylossless-qc-demo.yml diff --git a/.github/workflows/main_pylossless-qc-demo.yml b/.github/workflows/main_pylossless-qc-demo.yml new file mode 100644 index 0000000..04b2245 --- /dev/null +++ b/.github/workflows/main_pylossless-qc-demo.yml @@ -0,0 +1,63 @@ +# Docs for the Azure Web Apps Deploy action: https://github.com/Azure/webapps-deploy +# More GitHub Actions for Azure: https://github.com/Azure/actions +# More info on Python, GitHub Actions, and Azure App Service: https://aka.ms/python-webapps-actions + +name: Build and deploy Python app to Azure Web App - pylossless-qc-demo + +on: + push: + branches: + - main + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python version + uses: actions/setup-python@v1 + with: + python-version: '3.7' + + - name: Create and start virtual environment + run: | + python -m venv venv + source venv/bin/activate + + - name: Install dependencies + run: pip install -r requirements.txt + + # Optional: Add step to run tests here (PyTest, Django test suites, etc.) + + - name: Upload artifact for deployment jobs + uses: actions/upload-artifact@v2 + with: + name: python-app + path: | + . + !venv/ + + deploy: + runs-on: ubuntu-latest + needs: build + environment: + name: 'Production' + url: ${{ steps.deploy-to-webapp.outputs.webapp-url }} + + steps: + - name: Download artifact from build job + uses: actions/download-artifact@v2 + with: + name: python-app + path: . + + - name: 'Deploy to Azure Web App' + uses: azure/webapps-deploy@v2 + id: deploy-to-webapp + with: + app-name: 'pylossless-qc-demo' + slot-name: 'Production' + publish-profile: ${{ secrets.AZUREAPPSERVICE_PUBLISHPROFILE_B1F7B3988A884792AC90F3E962E01A2B }} From 6bda8832d7fc250873eea50331628aa0cd7268ad Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 1 Sep 2023 10:55:21 -0400 Subject: [PATCH 02/17] MAINT: update testing requirements and configuration - Use black to autoformat code - swap out flake8 for Ruff for linting - add precommit hook to test for style errors before commiting or pushing - add a pyproject toml file to configure black, ruff, codespell, etc - update github workflows style test --- .github/workflows/check_linting.yml | 21 +++++++++++---------- .pre-commit-config.yaml | 24 ++++++++++++++++++++++++ pyproject.toml | 21 +++++++++++++++++++++ requirements_testing.txt | 14 +++++++++----- 4 files changed, 65 insertions(+), 15 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.github/workflows/check_linting.yml b/.github/workflows/check_linting.yml index 701022f..4634d9f 100644 --- a/.github/workflows/check_linting.yml +++ b/.github/workflows/check_linting.yml @@ -1,22 +1,23 @@ -name: PEP8 Compliance, Spellcheck, & Docstring +name: Style, Spellcheck, & Docstring on: pull_request jobs: - linting: + style: + name: Style runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.x uses: actions/setup-python@v4 with: - python-version: "3.x" + python-version: '3.11' + - uses: psf/black@stable + - uses: pre-commit/action@v3.0.0 - name: Install dependencies run: pip install -r requirements_testing.txt - - name: Run flake8 - run: flake8 pylossless docs --exclude pylossless/__init__.py + # Run Ruff + - name: Run Ruff + run: ruff pylossless + # Run Codespell - name: Run Codespell - run: codespell pylossless docs --skip docs/source/generated - - name: Check Numpy Format Docstring - run: pydocstyle pylossless - + run: codespell pylossless docs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..252f7a1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + args: [--quiet] + + # Ruff linter + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.286 + hooks: + - id: ruff + name: ruff + files: ^pylossless/ + + # Codespell + - repo: https://github.com/codespell-project/codespell + rev: v2.2.5 + hooks: + - id: codespell + additional_dependencies: + - tomli + files: ^pylossless/|^docs/ + types_or: [python, bib, rst, inc] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..501abeb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[tool.codespell] +skip = "docs/source/generated" + +[tool.ruff] +select = ["E", "F", "W", "D"] # pycodestle, pyflakes, Warning, Docstring +exclude = ["__init__.py"] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D413", # Missing blank line after last section +] + +[tool.ruff.pydocstyle] +convention = "numpy" +ignore-decorators = [ + "property", + "setter", +] + +[tool.black] +exclude = "(dist/)|(build/)|(.*\\.ipynb)" # Exclude build artifacts and notebooks \ No newline at end of file diff --git a/requirements_testing.txt b/requirements_testing.txt index 3c23323..06a30fb 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -1,7 +1,11 @@ -pytest -flake8 +black codespell -pydocstyle -coverage dash[testing] -selenium \ No newline at end of file +numpydoc +pre-commit +pytest +pytest-cov +pydocstyle +ruff +selenium +tomli; python_version<'3.11' \ No newline at end of file From 718a485d5f09d40c9e79e6a42714fda800012a39 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 1 Sep 2023 11:01:07 -0400 Subject: [PATCH 03/17] MAINT, STY: reformat all files with Black - Black has now autoformatted all files - and remaining flakes have been fixed --- docs/examples/usage.py | 1 + docs/source/conf.py | 62 +-- pylossless/_logging.py | 17 +- pylossless/_utils.py | 4 +- pylossless/bids.py | 79 ++-- pylossless/config.py | 11 +- pylossless/dash/__init__.py | 15 +- pylossless/dash/app.py | 6 +- pylossless/dash/css_defaults.py | 157 ++++--- pylossless/dash/mne_visualizer.py | 375 ++++++++------- pylossless/dash/pylossless_qc.py | 21 +- pylossless/dash/qcannotations.py | 74 +-- pylossless/dash/qcgui.py | 351 +++++++------- pylossless/dash/tests/conftest.py | 4 +- pylossless/dash/tests/test_topo_viz.py | 28 +- pylossless/dash/topo_viz.py | 350 ++++++++------ pylossless/flagging.py | 42 +- pylossless/pipeline.py | 606 +++++++++++++------------ pylossless/tests/test_pipeline.py | 88 ++-- pylossless/tests/test_simulated.py | 131 +++--- setup.py | 22 +- 21 files changed, 1335 insertions(+), 1109 deletions(-) diff --git a/docs/examples/usage.py b/docs/examples/usage.py index 17a5439..b5276c9 100644 --- a/docs/examples/usage.py +++ b/docs/examples/usage.py @@ -17,5 +17,6 @@ # Then, we import the function we need. # For use in jupyter notebooks, We just need to import a single function. from pylossless.dash.app import get_app + app = get_app(kind="jupyter") app.run_server(mode="inline") diff --git a/docs/source/conf.py b/docs/source/conf.py index fa5c90e..17a04eb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,15 +9,15 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'pyLossless' +project = "pyLossless" copyright = "2023, Huberty, Scott; O'Reilly, Christian; Desjardins, James" author = "Huberty, Scott; O'reilly, Christian" -release = '0.1' +release = "0.1" # Point Sphinx.ext.autodoc to the our python modules (two parent directories # from this dir) -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -29,13 +29,15 @@ # gallery for: building tutorial .rst files from python files # sphinxemoji So we can use emoji's in docs. # sphinx design to support certain directives, like ::grid etc. -extensions = ['sphinx.ext.intersphinx', - 'sphinx.ext.autodoc', - 'numpydoc', - 'sphinx.ext.todo', - 'sphinx_gallery.gen_gallery', - 'sphinxemoji.sphinxemoji', - "sphinx_design"] +extensions = [ + "sphinx.ext.intersphinx", + "sphinx.ext.autodoc", + "numpydoc", + "sphinx.ext.todo", + "sphinx_gallery.gen_gallery", + "sphinxemoji.sphinxemoji", + "sphinx_design", +] # Allows us to use the ..todo:: directive todo_include_todos = True @@ -44,45 +46,43 @@ # Source directory of python file tutorials and the target # directory for the converted rST files sphinx_gallery_conf = { - 'examples_dirs': '../examples', # path to your example scripts - 'gallery_dirs': './generated/auto_tutorials', # path to save tutorials + "examples_dirs": "../examples", # path to your example scripts + "gallery_dirs": "./generated/auto_tutorials", # path to save tutorials } -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] # -- Intersphinx configuration ----------------------------------------------- intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'pandas': ('https://pandas.pydata.org/docs/', None), - 'xarray': ('https://docs.xarray.dev/en/stable/', None), - 'mne': ('https://mne.tools/dev', None), - 'mne_icalabel': ('https://mne.tools/mne-icalabel/dev', None) + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "pandas": ("https://pandas.pydata.org/docs/", None), + "xarray": ("https://docs.xarray.dev/en/stable/", None), + "mne": ("https://mne.tools/dev", None), + "mne_icalabel": ("https://mne.tools/mne-icalabel/dev", None), } # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'pydata_sphinx_theme' -html_static_path = ['_static'] +html_theme = "pydata_sphinx_theme" +html_static_path = ["_static"] html_theme_options = { - "logo": { - "image_light": "logo-lightmode_color.png", - "image_dark": "logo_white.png", - } + "logo": { + "image_light": "logo-lightmode_color.png", + "image_dark": "logo_white.png", + } } # user made CSS to customize look html_css_files = [ - 'css/custom.css', + "css/custom.css", ] # Custom sidebar templates, maps document names to template names. -html_sidebars = { - "index": ["search-field.html", "sidebar-nav-bs", 'globaltoc.html'] -} +html_sidebars = {"index": ["search-field.html", "sidebar-nav-bs", "globaltoc.html"]} # NumPyDoc configuration ----------------------------------------------------- @@ -96,7 +96,7 @@ numpydoc_xref_param_type = True numpydoc_validate = True # Only generate documentation for public members -autodoc_default_flags = ['members', 'undoc-members', 'inherited-members'] +autodoc_default_flags = ["members", "undoc-members", "inherited-members"] numpydoc_class_members_toctree = False numpydoc_xref_aliases = { @@ -126,5 +126,5 @@ "in", "dtype", "object", - "LosslessPipeline" + "LosslessPipeline", } diff --git a/pylossless/_logging.py b/pylossless/_logging.py index bf5c7d2..c44a98b 100644 --- a/pylossless/_logging.py +++ b/pylossless/_logging.py @@ -3,17 +3,16 @@ from mne.utils import logger from functools import wraps -CONF_MAP = {'run_staging_script': 'staging_script', - 'find_breaks': 'find_breaks'} +CONF_MAP = {"run_staging_script": "staging_script", "find_breaks": "find_breaks"} def _is_step_run(func, instance): """Check if step was run.""" - if func.__name__ in ['run_staging_script', 'find_breaks']: + if func.__name__ in ["run_staging_script", "find_breaks"]: step = CONF_MAP[func.__name__] - return (True if - step not in instance.config or not instance.config[step] - else False) + return ( + True if step not in instance.config or not instance.config[step] else False + ) else: return False @@ -49,14 +48,15 @@ def wrapper(*args, message=None, **kwargs): end_time = time.time() dur = f"{end_time - start_time:.2f}" if verbose: - logger.info(f"LOSSLESS: 🏁 Finished {this_step} after {dur}" - " seconds.") + logger.info(f"LOSSLESS: 🏁 Finished {this_step} after {dur}" " seconds.") return result + return wrapper def lossless_time(func): """Log the time of a full pipeline run.""" + def wrapper(*args, **kwargs): logger.info(" ⏩ LOSSLESS: Starting Pylossless Pipeline.") start_time = time.time() @@ -65,4 +65,5 @@ def wrapper(*args, **kwargs): dur = f"{(end_time - start_time) / 60:.2f}" logger.info(f" ✅ LOSSLESS: Pipeline completed! took {dur} minutes.") return result + return wrapper diff --git a/pylossless/_utils.py b/pylossless/_utils.py index 3a8ffc9..2cf5f4d 100644 --- a/pylossless/_utils.py +++ b/pylossless/_utils.py @@ -12,7 +12,7 @@ def _icalabel_to_data_frame(ica): """Export IClabels to pandas DataFrame.""" - ic_type = [''] * ica.n_components_ + ic_type = [""] * ica.n_components_ for label, comps in ica.labels_.items(): for comp in comps: ic_type[comp] = label @@ -23,6 +23,6 @@ def _icalabel_to_data_frame(ica): component=ica._ica_names, annotator=["ic_label"] * ica.n_components_, ic_type=ic_type, - confidence=ica.labels_scores_.max(1) + confidence=ica.labels_scores_.max(1), ) ) diff --git a/pylossless/bids.py b/pylossless/bids.py index 71fc23f..90ba861 100644 --- a/pylossless/bids.py +++ b/pylossless/bids.py @@ -10,9 +10,7 @@ # TODO: Add parameters and return. -def get_bids_path(bids_path_kwargs, - datatype='eeg', - bids_root='./bids_dataset'): +def get_bids_path(bids_path_kwargs, datatype="eeg", bids_root="./bids_dataset"): """Get BIDS path from BIDS recording.""" if "datatype" not in bids_path_kwargs: bids_path_kwargs["datatype"] = datatype @@ -23,17 +21,23 @@ def get_bids_path(bids_path_kwargs, # TODO: Add parameters and return. -def get_dataset_bids_path(bids_path_args, - datatype='eeg', - bids_root='./bids_dataset'): +def get_dataset_bids_path(bids_path_args, datatype="eeg", bids_root="./bids_dataset"): """Getter method for BIDS path from BIDS dataset.""" - return [get_bids_path(bids_path_kwargs, datatype, bids_root) - for bids_path_kwargs in bids_path_args] - - -def convert_recording_to_bids(import_func, import_kwargs, bids_path_kwargs, - datatype='eeg', bids_root='./bids_dataset', - import_events=True, **write_kwargs): + return [ + get_bids_path(bids_path_kwargs, datatype, bids_root) + for bids_path_kwargs in bids_path_args + ] + + +def convert_recording_to_bids( + import_func, + import_kwargs, + bids_path_kwargs, + datatype="eeg", + bids_root="./bids_dataset", + import_events=True, + **write_kwargs +): """Convert a dataset to BIDS. Parameters @@ -54,6 +58,7 @@ def convert_recording_to_bids(import_func, import_kwargs, bids_path_kwargs, constructor of the `mne_bids.BIDSPath` class. import events: boolean Whether to import a provided events object + Returns ------- bids_paths : list of instance of `mne_bids.BIDSPath` @@ -72,16 +77,22 @@ def convert_recording_to_bids(import_func, import_kwargs, bids_path_kwargs, if "allow_preload" not in write_kwargs: write_kwargs["allow_preload"] = True - write_raw_bids(raw, bids_path=bids_path, - events_data=events, event_id=event_id, - **write_kwargs) + write_raw_bids( + raw, bids_path=bids_path, events_data=events, event_id=event_id, **write_kwargs + ) return bids_path -def convert_dataset_to_bids(import_funcs, import_args, bids_path_args, - datatype='eeg', bids_root='./bids_dataset', - import_events=True, **write_kwargs): +def convert_dataset_to_bids( + import_funcs, + import_args, + bids_path_args, + datatype="eeg", + bids_root="./bids_dataset", + import_events=True, + **write_kwargs +): """Convert a dataset to BIDS. Parameters @@ -109,28 +120,32 @@ def convert_dataset_to_bids(import_funcs, import_args, bids_path_args, as import_args. import events: boolean Whether to import a provided events object. + Returns ------- bids_paths : list of instance of `mne_bids.BIDSPath` `mne_bids.BIDSPath` for the different recordings """ - assert (len(import_args) == len(bids_path_args)) + assert len(import_args) == len(bids_path_args) if isinstance(import_funcs, list): - assert (len(import_args) == len(import_funcs)) + assert len(import_args) == len(import_funcs) else: - import_funcs = [import_funcs]*len(import_args) + import_funcs = [import_funcs] * len(import_args) bids_paths = [] - for import_kwargs, bids_path_kwargs, func in zip(import_args, - bids_path_args, - import_funcs): + for import_kwargs, bids_path_kwargs, func in zip( + import_args, bids_path_args, import_funcs + ): bids_paths.append( - convert_recording_to_bids(func, - import_kwargs, - bids_path_kwargs, - datatype=datatype, - bids_root=bids_root, - import_events=import_events, - **write_kwargs)) + convert_recording_to_bids( + func, + import_kwargs, + bids_path_kwargs, + datatype=datatype, + bids_root=bids_root, + import_events=import_events, + **write_kwargs + ) + ) return bids_paths diff --git a/pylossless/config.py b/pylossless/config.py index f12f1b1..3436bc5 100644 --- a/pylossless/config.py +++ b/pylossless/config.py @@ -16,15 +16,15 @@ class Config(dict): """Representation of configuration files for a pipeline procedure.""" - DEFAULT_CONFIG_PATH = (Path(__file__).parent / - "assets" / "ll_default_config.yaml") + DEFAULT_CONFIG_PATH = Path(__file__).parent / "assets" / "ll_default_config.yaml" def read(self, file_name): """Read a saved pylossless config file.""" file_name = Path(file_name) if not file_name.exists(): - raise FileExistsError(f'Configuration file {file_name.absolute()} ' - 'does not exist') + raise FileExistsError( + f"Configuration file {file_name.absolute()} " "does not exist" + ) with file_name.open("r") as init_variables_file: self.update(yaml.safe_load(init_variables_file)) @@ -40,8 +40,7 @@ def save(self, file_name): """Save the current Config object to disk.""" file_name = Path(file_name) with file_name.open("w") as init_variables_file: - yaml.dump(dict(self), init_variables_file, - indent=4, sort_keys=True) + yaml.dump(dict(self), init_variables_file, indent=4, sort_keys=True) def print(self): """Print the Config contents.""" diff --git a/pylossless/dash/__init__.py b/pylossless/dash/__init__.py index 73234af..e0d9062 100644 --- a/pylossless/dash/__init__.py +++ b/pylossless/dash/__init__.py @@ -2,7 +2,18 @@ from mne_icalabel.config import ICLABEL_LABELS_TO_MNE -IC_COLORS = ['#2c2c2c', '#003e83', 'cyan', 'goldenrod', 'magenta', '#b08699', - '#96bfe6', 'brown', 'yellowgreen', 'burlywood', 'plum'] +IC_COLORS = [ + "#2c2c2c", + "#003e83", + "cyan", + "goldenrod", + "magenta", + "#b08699", + "#96bfe6", + "brown", + "yellowgreen", + "burlywood", + "plum", +] ic_label_cmap = dict(zip(ICLABEL_LABELS_TO_MNE.values(), IC_COLORS)) diff --git a/pylossless/dash/app.py b/pylossless/dash/app.py index abc417b..cdeb92b 100644 --- a/pylossless/dash/app.py +++ b/pylossless/dash/app.py @@ -12,14 +12,14 @@ def get_app(fpath=None, project_root=None, disable_buttons=False, kind="dash"): """Call either Dash or Jupyter for Lossless QC procedure.""" if kind == "jupyter": from jupyter_dash import JupyterDash + app = JupyterDash(__name__, external_stylesheets=[dbc.themes.SLATE]) else: app = dash.Dash(__name__, external_stylesheets=[dbc.themes.SLATE]) - QCGUI(app, fpath=fpath, project_root=project_root, - disable_buttons=disable_buttons) + QCGUI(app, fpath=fpath, project_root=project_root, disable_buttons=disable_buttons) return app -if __name__ == '__main__': +if __name__ == "__main__": get_app().run_server(debug=False, use_reloader=False) diff --git a/pylossless/dash/css_defaults.py b/pylossless/dash/css_defaults.py index 781bcc0..e38065d 100644 --- a/pylossless/dash/css_defaults.py +++ b/pylossless/dash/css_defaults.py @@ -5,59 +5,72 @@ # Default Layout for individual scatter plots within timeseries ############################################################### -drawn_shapes_format = {"drawdirection": "vertical", - "layer": "below", - "fillcolor": "red", - "opacity": 0.51, - "line": {"width": 0}} - -drawn_selections_format = {'line': dict(color='crimson', width=2)} - - -DEFAULT_LAYOUT_XAXIS = {'zeroline': False, - 'showgrid': True, - 'title': "time (seconds)", - 'gridcolor': 'white', - 'fixedrange': True, - 'showline': True, - 'titlefont': dict(color='#ADB5BD'), - 'tickfont': dict(color='#ADB5BD'), - 'automargin': True - } - -DEFAULT_LAYOUT_YAXIS = {'showgrid': True, - 'showline': True, - 'zeroline': False, - 'autorange': False, # 'reversed', - 'scaleratio': 0.5, - "tickmode": "array", - 'titlefont': dict(color='#ADB5BD'), - 'tickfont': dict(color='#ADB5BD'), - 'fixedrange': True, - 'automargin': True} +drawn_shapes_format = { + "drawdirection": "vertical", + "layer": "below", + "fillcolor": "red", + "opacity": 0.51, + "line": {"width": 0}, +} + +drawn_selections_format = {"line": dict(color="crimson", width=2)} + + +DEFAULT_LAYOUT_XAXIS = { + "zeroline": False, + "showgrid": True, + "title": "time (seconds)", + "gridcolor": "white", + "fixedrange": True, + "showline": True, + "titlefont": dict(color="#ADB5BD"), + "tickfont": dict(color="#ADB5BD"), + "automargin": True, +} + +DEFAULT_LAYOUT_YAXIS = { + "showgrid": True, + "showline": True, + "zeroline": False, + "autorange": False, # 'reversed', + "scaleratio": 0.5, + "tickmode": "array", + "titlefont": dict(color="#ADB5BD"), + "tickfont": dict(color="#ADB5BD"), + "fixedrange": True, + "automargin": True, +} DEFAULT_LAYOUT = dict( # height=400, - # width=850, - xaxis=DEFAULT_LAYOUT_XAXIS, - yaxis=DEFAULT_LAYOUT_YAXIS, - showlegend=False, - margin={'t': 15, 'b': 0, 'l': 35, 'r': 5}, - # {'t': 15,'b': 25, 'l': 35, 'r': 5}, - paper_bgcolor="rgba(0,0,0,0)", - plot_bgcolor="#EAEAF2", - shapes=[], - dragmode='select', - newshape=drawn_shapes_format, - newselection=drawn_selections_format, - activeshape=dict(fillcolor='crimson', opacity=.75), - hovermode='closest') - -WATERMARK_ANNOT = (dict(text='NO FILE SELECTED', - textangle=0, opacity=0.1, - font=dict(color='red', size=80), - xref='paper', yref='paper', x=.5, y=.5, - showarrow=False), - ) + # width=850, + xaxis=DEFAULT_LAYOUT_XAXIS, + yaxis=DEFAULT_LAYOUT_YAXIS, + showlegend=False, + margin={"t": 15, "b": 0, "l": 35, "r": 5}, + # {'t': 15,'b': 25, 'l': 35, 'r': 5}, + paper_bgcolor="rgba(0,0,0,0)", + plot_bgcolor="#EAEAF2", + shapes=[], + dragmode="select", + newshape=drawn_shapes_format, + newselection=drawn_selections_format, + activeshape=dict(fillcolor="crimson", opacity=0.75), + hovermode="closest", +) + +WATERMARK_ANNOT = ( + dict( + text="NO FILE SELECTED", + textangle=0, + opacity=0.1, + font=dict(color="red", size=80), + xref="paper", + yref="paper", + x=0.5, + y=0.5, + showarrow=False, + ), +) CSS = dict() STYLE = dict() @@ -71,78 +84,78 @@ # Empty # bootstrap format for channel slider div: self.channel_slider_div -CSS['ch-slider-div'] = "d-inline-block align-top" +CSS["ch-slider-div"] = "d-inline-block align-top" # bootstrap format for time slider # Empty # bootstrap for time slider div -CSS['time-slider-div'] = "w-100" +CSS["time-slider-div"] = "w-100" # bootstrap for timeseries: self.graph -CSS['timeseries'] = "w-100 d-inline-block" # border border-info -STYLE['timeseries'] = {'height': '40vh'} +CSS["timeseries"] = "w-100 d-inline-block" # border border-info +STYLE["timeseries"] = {"height": "40vh"} # bootstrap for timeseries-div: self.graph_div; border border-warnings -CSS['timeseries-div'] = "mh-100 d-inline-block shadow-lg" -STYLE['timeseries-div'] = {'width': '95%'} +CSS["timeseries-div"] = "mh-100 d-inline-block shadow-lg" +STYLE["timeseries-div"] = {"width": "95%"} # bootstrap for timeseries-container: self.container_plot -CSS['timeseries-container'] = "w-100" # border border-success +CSS["timeseries-container"] = "w-100" # border border-success ############################################################ # TOPO PLOTS ############################################################ # bootstrap format for topo slider div -CSS['topo-slider-div'] = "d-inline-block align-middle" +CSS["topo-slider-div"] = "d-inline-block align-middle" # bootstrap for topo dcc-graph -CSS['topo-dcc'] = "bg-secondary bg-opacity-50 border rounded" # border-info +CSS["topo-dcc"] = "bg-secondary bg-opacity-50 border rounded" # border-info # bootstrap for div containing topo-dcc; border border-warning -CSS['topo-dcc-div'] = 'bg-secondary bg-opacity-50 d-inline-block align-top' -STYLE['topo-dcc-div'] = {'width': '90%'} # so that slider can fit to the left +CSS["topo-dcc-div"] = "bg-secondary bg-opacity-50 d-inline-block align-top" +STYLE["topo-dcc-div"] = {"width": "90%"} # so that slider can fit to the left # bootstrap for final container: self.container_plot -CSS['topo-container'] = 'topo-div shadow-lg' # border border-success +CSS["topo-container"] = "topo-div shadow-lg" # border border-success ##################### # QC DASHBOARD LAYOUT ##################### # QC container -STYLE['qc-container'] = {} # {'border': '2px solid yellow'} +STYLE["qc-container"] = {} # {'border': '2px solid yellow'} # File control Row -CSS['file-row'] = 'h-20' +CSS["file-row"] = "h-20" # STYLE['file-row'] = {'border': '2px solid orange'} # Select Folder & Save Button Col # STYLE['folder-col'] = {'border': '2px solid red'} # Select Folder & Save Buttons -CSS['button'] = "d-md-inline-block" +CSS["button"] = "d-md-inline-block" # Dropdown Col # STYLE['dropdown-col'] = {'border': '2px dashed purple'} # Dropdown -CSS['dropdown'] = "d-md-inline-block w-100" +CSS["dropdown"] = "d-md-inline-block w-100" # Logo Col # Empty # Logo -CSS['logo'] = 'ms-5 mh-100' +CSS["logo"] = "ms-5 mh-100" # Timeseries & Topo bootstrap Row -CSS['plots-row'] = 'h-80' -STYLE['plots-row'] = {} # {'border': '2px dashed pink'} +CSS["plots-row"] = "h-80" +STYLE["plots-row"] = {} # {'border': '2px dashed pink'} # Timeseries plots Col -CSS['timeseries-col'] = 'w-100 h-100 mh-100' +CSS["timeseries-col"] = "w-100 h-100 mh-100" # Topoplots Col -CSS['topo-col'] = 'h-100 mt-2 mb-2' # border border-danger +CSS["topo-col"] = "h-100 mt-2 mb-2" # border border-danger diff --git a/pylossless/dash/mne_visualizer.py b/pylossless/dash/mne_visualizer.py index a040af5..f73e749 100644 --- a/pylossless/dash/mne_visualizer.py +++ b/pylossless/dash/mne_visualizer.py @@ -23,18 +23,29 @@ def _add_watermark_annot(): from .css_defaults import WATERMARK_ANNOT + return WATERMARK_ANNOT class MNEVisualizer: """Visualize an mne.io.raw object in a dash graph.""" - def __init__(self, app, inst, dcc_graph_kwargs=None, - dash_id_suffix=None, - show_time_slider=True, show_ch_slider=True, - scalings='auto', zoom=2, remove_dc=True, - annot_created_callback=None, refresh_inputs=None, - show_n_channels=20, set_callbacks=True): + def __init__( + self, + app, + inst, + dcc_graph_kwargs=None, + dash_id_suffix=None, + show_time_slider=True, + show_ch_slider=True, + scalings="auto", + zoom=2, + remove_dc=True, + annot_created_callback=None, + refresh_inputs=None, + show_n_channels=20, + set_callbacks=True, + ): """Initialize class. Parameters @@ -61,6 +72,7 @@ def __init__(self, app, inst, dcc_graph_kwargs=None, show_time_slider : bool Whether to show the channel slider with the MNEVIsualizer time-series graph. Defaults to True. + Returns ------- an instance of MNEVisualizer. @@ -84,31 +96,42 @@ def __init__(self, app, inst, dcc_graph_kwargs=None, self.annotating_start = None self.annotation_inprogress = None self.annot_created_callback = annot_created_callback - self.new_annot_desc = 'selected_time' + self.new_annot_desc = "selected_time" self.mne_annots = None self.loading_div = None # setting component ids based on dash_id_suffix - default_ids = ['graph', 'ch-slider', 'time-slider', - 'container-plot', 'output', 'mne-annotations', - 'loading', "loading-output"] - self.dash_ids = {id_: (id_ + f'_{dash_id_suffix}') - if dash_id_suffix else id_ - for id_ in default_ids} - modebar_buttons = {'modeBarButtonsToAdd': ["eraseshape"], - 'modeBarButtonsToRemove': ['zoom', 'pan']} - self.dcc_graph_kwargs = dict(id=self.dash_ids['graph'], - className=CSS['timeseries'], - style=STYLE['timeseries'], - figure={'data': None, - 'layout': None}, - config=modebar_buttons) + default_ids = [ + "graph", + "ch-slider", + "time-slider", + "container-plot", + "output", + "mne-annotations", + "loading", + "loading-output", + ] + self.dash_ids = { + id_: (id_ + f"_{dash_id_suffix}") if dash_id_suffix else id_ + for id_ in default_ids + } + modebar_buttons = { + "modeBarButtonsToAdd": ["eraseshape"], + "modeBarButtonsToRemove": ["zoom", "pan"], + } + self.dcc_graph_kwargs = dict( + id=self.dash_ids["graph"], + className=CSS["timeseries"], + style=STYLE["timeseries"], + figure={"data": None, "layout": None}, + config=modebar_buttons, + ) if dcc_graph_kwargs is not None: self.dcc_graph_kwargs.update(dcc_graph_kwargs) self.graph = dcc.Graph(**self.dcc_graph_kwargs) - self.graph_div = html.Div([self.graph], - style=STYLE['timeseries-div'], - className=CSS['timeseries-div']) + self.graph_div = html.Div( + [self.graph], style=STYLE["timeseries-div"], className=CSS["timeseries-div"] + ) self.show_time_slider = show_time_slider self.show_ch_slider = show_ch_slider self.inst = inst @@ -137,8 +160,7 @@ def load_recording(self, raw): marks_keys = np.round(np.linspace(self.times[0], self.times[-1], 10)) self.time_slider.min = self.times[0] self.time_slider.max = self.times[-1] - self.win_size - self.time_slider.marks = {int(key): str(int(key)) - for key in marks_keys} + self.time_slider.marks = {int(key): str(int(key)) for key in marks_keys} self.initialize_shapes() self.update_layout() @@ -151,22 +173,24 @@ def inst(self): def inst(self, inst): if not inst: return - _validate_type(inst, (BaseEpochs, BaseRaw, Evoked), 'inst') + _validate_type(inst, (BaseEpochs, BaseRaw, Evoked), "inst") self._inst = inst self.inst.load_data() - self.scalings = dict(mag=1e-12, - grad=4e-11, - eeg=20e-6, - eog=150e-6, - ecg=5e-4, - emg=1e-3, - ref_meg=1e-12, - misc=1e-3, - stim=1, - resp=1, - chpi=1e-4, - whitened=1e2) - if self.scalings_arg == 'auto': + self.scalings = dict( + mag=1e-12, + grad=4e-11, + eeg=20e-6, + eog=150e-6, + ecg=5e-4, + emg=1e-3, + ref_meg=1e-12, + misc=1e-3, + stim=1, + resp=1, + chpi=1e-4, + whitened=1e2, + ) + if self.scalings_arg == "auto": for kind in np.unique(self.inst.get_channel_types()): self.scalings[kind] = np.percentile(self.inst.get_data(), 99.5) else: @@ -184,8 +208,7 @@ def initialize_shapes(self): """Make graph.layout.shapes for each mne.io.raw.annotation.""" if not self.inst: return - self.mne_annots.data = EEGAnnotationList.from_mne_inst(self.inst, - self.layout) + self.mne_annots.data = EEGAnnotationList.from_mne_inst(self.inst, self.layout) def refresh_shapes(self): """Identify shapes that are viewable in the current time-window.""" @@ -206,47 +229,50 @@ def update_inst_annnotations(self): annots = self.mne_annots.data.to_mne_annotation() self.inst.set_annotations(annots) -############################ -# Create Timeseries Layouts -############################ + ############################ + # Create Timeseries Layouts + ############################ @property def layout(self): """Return MNEVIsualizer.graph.figure.layout.""" - return self.graph.figure['layout'] + return self.graph.figure["layout"] @layout.setter def layout(self, layout): - self.graph.figure['layout'] = layout + self.graph.figure["layout"] = layout def initialize_layout(self): """Create MNEVisualizer.graph.figure.layout.""" if not self.inst: - DEFAULT_LAYOUT['annotations'] = _add_watermark_annot() + DEFAULT_LAYOUT["annotations"] = _add_watermark_annot() tickvals_handler = np.arange(-self.n_sel_ch + 1, 1) - DEFAULT_LAYOUT['yaxis'].update({"tickvals": tickvals_handler, - 'ticktext': [''] * self.n_sel_ch, - 'range': [-self.n_sel_ch, 1]}) + DEFAULT_LAYOUT["yaxis"].update( + { + "tickvals": tickvals_handler, + "ticktext": [""] * self.n_sel_ch, + "range": [-self.n_sel_ch, 1], + } + ) tmin = self.win_start tmax = self.win_start + self.win_size - DEFAULT_LAYOUT['xaxis'].update({'range': [tmin, tmax]}) + DEFAULT_LAYOUT["xaxis"].update({"range": [tmin, tmax]}) self.layout = go.Layout(**DEFAULT_LAYOUT) - trace_kwargs = {'x': [], - 'y': [], - 'mode': 'lines', - 'line': dict(color='#2c2c2c', width=1) - } + trace_kwargs = { + "x": [], + "y": [], + "mode": "lines", + "line": dict(color="#2c2c2c", width=1), + } # create objects for layout and traces - self.traces = [go.Scatter(name=ii, **trace_kwargs) - for ii in range(self.n_sel_ch)] - self.update_layout(ch_slider_val=self.channel_slider.max, - time_slider_val=0) - - def update_layout(self, - ch_slider_val=None, - time_slider_val=None): + self.traces = [ + go.Scatter(name=ii, **trace_kwargs) for ii in range(self.n_sel_ch) + ] + self.update_layout(ch_slider_val=self.channel_slider.max, time_slider_val=0) + + def update_layout(self, ch_slider_val=None, time_slider_val=None): """Update MNEVisualizer.graph.figure.layout.""" if not self.inst: return @@ -256,7 +282,7 @@ def update_layout(self, self.win_start = time_slider_val tmin, tmax = self.win_start, self.win_start + self.win_size - self.layout.xaxis.update({'range': [tmin, tmax]}) + self.layout.xaxis.update({"range": [tmin, tmax]}) # Update selected channels first_sel_ch = self._ch_slider_val - self.n_sel_ch + 1 @@ -273,28 +299,29 @@ def update_layout(self, # Update the raw timeseries traces ch_names = self.inst.ch_names[::-1][first_sel_ch:last_sel_ch] - self.layout.yaxis['ticktext'] = ch_names + self.layout.yaxis["ticktext"] = ch_names ch_types_list = self.inst.get_channel_types() ch_types = ch_types_list[::-1][first_sel_ch:last_sel_ch] - ch_zip = zip(range(1, self.n_sel_ch+1), ch_names, - data, self.traces, ch_types) + ch_zip = zip(range(1, self.n_sel_ch + 1), ch_names, data, self.traces, ch_types) for i, ch_name, signal, trace, ch_type in ch_zip: trace.x = np.round(times, 3) step_trace = signal / self._get_norm_factor(ch_type) trace.y = step_trace + i - self.n_sel_ch trace.name = ch_name - if ch_name in self.inst.info['bads']: - trace.line.color = '#d3d3d3' + if ch_name in self.inst.info["bads"]: + trace.line.color = "#d3d3d3" else: - trace.line.color = '#2c2c2c' + trace.line.color = "#2c2c2c" # Hover template will show Channel number and Time trace.text = np.round(signal * 1e6, 3) # Volts to microvolts - trace.hovertemplate = (f'Channel: {ch_name}
' + - 'Time: %{x}s
' + - 'Amplitude: %{text}uV
' + - '') + trace.hovertemplate = ( + f"Channel: {ch_name}
" + + "Time: %{x}s
" + + "Amplitude: %{text}uV
" + + "" + ) - self.graph.figure['data'] = self.traces + self.graph.figure["data"] = self.traces self.refresh_shapes() @@ -304,85 +331,88 @@ def update_layout(self, def set_callback(self): """Set the dash callback for the MNE.Visualizer object.""" - args = [Output(self.dash_ids['graph'], 'figure'), - Input(self.dash_ids['ch-slider'], 'value'), - Input(self.dash_ids['time-slider'], 'value'), - Input(self.dash_ids['graph'], "clickData"), - Input(self.dash_ids['graph'], "relayoutData"), - ] + args = [ + Output(self.dash_ids["graph"], "figure"), + Input(self.dash_ids["ch-slider"], "value"), + Input(self.dash_ids["time-slider"], "value"), + Input(self.dash_ids["graph"], "clickData"), + Input(self.dash_ids["graph"], "relayoutData"), + ] if self.refresh_inputs: args += self.refresh_inputs - @self.app.callback(*args, suppress_callback_exceptions=False, - prevent_initial_call=False) + @self.app.callback( + *args, suppress_callback_exceptions=False, prevent_initial_call=False + ) def callback(ch, time, click_data, relayout_data, *args): if not self.inst: return dash.no_update - update_layout_ids = [self.dash_ids['ch-slider'], - self.dash_ids['time-slider'], - ] + update_layout_ids = [ + self.dash_ids["ch-slider"], + self.dash_ids["time-slider"], + ] if self.refresh_inputs: - update_layout_ids.extend([inp.component_id - for inp - in self.refresh_inputs]) + update_layout_ids.extend( + [inp.component_id for inp in self.refresh_inputs] + ) update_layout = False ctx = dash.callback_context - events = [event['prop_id'].split('.') for event in ctx.triggered - if len(event['prop_id'].split('.')) == 2] + events = [ + event["prop_id"].split(".") + for event in ctx.triggered + if len(event["prop_id"].split(".")) == 2 + ] for object_, dash_event in events: - - if object_ == self.dash_ids['graph']: - if dash_event == 'clickData': + if object_ == self.dash_ids["graph"]: + if dash_event == "clickData": # Working on traces - logger.debug('** Trace selected') + logger.debug("** Trace selected") c_index = click_data["points"][0]["curveNumber"] ch_name = self.traces[c_index].name - if ch_name in self.inst.info['bads']: - self.inst.info['bads'].pop() + if ch_name in self.inst.info["bads"]: + self.inst.info["bads"].pop() else: - self.inst.info['bads'].append(ch_name) + self.inst.info["bads"].append(ch_name) update_layout = True - elif dash_event == 'relayoutData': + elif dash_event == "relayoutData": # Working on annotations - logger.debug(f'** relayoutData: {relayout_data}') + logger.debug(f"** relayoutData: {relayout_data}") if "selections" in relayout_data: # shape creation - logger.debug('** shape created') - onset = relayout_data["selections"][0]['x0'] - offset = relayout_data["selections"][0]['x1'] + logger.debug("** shape created") + onset = relayout_data["selections"][0]["x0"] + offset = relayout_data["selections"][0]["x1"] description = self.new_annot_desc - annot = EEGAnnotation(onset, offset-onset, - description, self.layout) + annot = EEGAnnotation( + onset, offset - onset, description, self.layout + ) self.mne_annots.data.append(annot) elif "shapes" in relayout_data: # shape was deleted - logger.debug('** shape deleted') - updated_shapes = relayout_data['shapes'] + logger.debug("** shape deleted") + updated_shapes = relayout_data["shapes"] if len(updated_shapes) < len(self.layout.shapes): # Shape (i.e. annotation) was deleted - previous_names = [shape['name'] for - shape in self.layout.shapes] - new_names = [shape['name'] for - shape in updated_shapes] + previous_names = [ + shape["name"] for shape in self.layout.shapes + ] + new_names = [shape["name"] for shape in updated_shapes] deleted = set(previous_names) - set(new_names) self.mne_annots.data.remove(deleted.pop()) - elif any([key.endswith('x0') - for key in relayout_data.keys()]): + elif any([key.endswith("x0") for key in relayout_data.keys()]): # shape was modified - logger.debug('** shape modified') - shape_str = (list(relayout_data.keys())[0] - .split(".")[0] - ) + logger.debug("** shape modified") + shape_str = list(relayout_data.keys())[0].split(".")[0] x0 = relayout_data[f"{shape_str}.x0"] x1 = relayout_data[f"{shape_str}.x1"] - shape_i = int(shape_str.split('[', 1)[1][:-1]) - name = self.layout.shapes[shape_i]['name'] + shape_i = int(shape_str.split("[", 1)[1][:-1]) + name = self.layout.shapes[shape_i]["name"] if name in self.mne_annots.data: annot = self.mne_annots.data[name] annot.onset = x0 @@ -421,38 +451,42 @@ def times(self): def _init_sliders(self): """Initialize the Channel and Time dcc.Slider components.""" - self.channel_slider = dcc.Slider(id=self.dash_ids["ch-slider"], - min=self.n_sel_ch - 1, - max=self.nb_channels - 1, - step=1, - marks=None, - value=self.nb_channels - 1, - included=False, - updatemode='mouseup', - vertical=True, - verticalHeight=300) - self.channel_slider_div = html.Div(self.channel_slider, - className=CSS['ch-slider-div'], - style={}) + self.channel_slider = dcc.Slider( + id=self.dash_ids["ch-slider"], + min=self.n_sel_ch - 1, + max=self.nb_channels - 1, + step=1, + marks=None, + value=self.nb_channels - 1, + included=False, + updatemode="mouseup", + vertical=True, + verticalHeight=300, + ) + self.channel_slider_div = html.Div( + self.channel_slider, className=CSS["ch-slider-div"], style={} + ) if not self.show_ch_slider: - self.channel_slider_div.style.update({'display': 'none'}) + self.channel_slider_div.style.update({"display": "none"}) marks_keys = np.round(np.linspace(self.times[0], self.times[-1], 10)) marks_dict = {int(key): str(int(key)) for key in marks_keys} max_ = self.times[-1] - self.win_size - self.time_slider = dcc.Slider(id=self.dash_ids['time-slider'], - min=self.times[0], - max=max_ if max_ > 0 else 0, - marks=marks_dict, - value=self.win_start, - vertical=False, - included=False, - updatemode='mouseup') - self.time_slider_div = html.Div(self.time_slider, - className=CSS['time-slider-div'], - style={}) + self.time_slider = dcc.Slider( + id=self.dash_ids["time-slider"], + min=self.times[0], + max=max_ if max_ > 0 else 0, + marks=marks_dict, + value=self.win_start, + vertical=False, + included=False, + updatemode="mouseup", + ) + self.time_slider_div = html.Div( + self.time_slider, className=CSS["time-slider-div"], style={} + ) if not self.show_time_slider: - self.time_slider_div.style.update({'display': 'none'}) + self.time_slider_div.style.update({"display": "none"}) def _init_annot_store(self): """Initialize the dcc.Store component of mne annotations.""" @@ -461,9 +495,9 @@ def _init_annot_store(self): def _set_loading_icon(self): """Add the loading icon.""" loading = dcc.Loading( - id=self.dash_ids['loading'], + id=self.dash_ids["loading"], type="circle", - children=html.Div(id=self.dash_ids['loading-output']) + children=html.Div(id=self.dash_ids["loading-output"]), ) self.loading_div = html.Div(loading) self.graph_div.children.append(self.loading_div) @@ -472,14 +506,18 @@ def _set_div(self): """Build the final html.Div component to be returned.""" # include both the timeseries graph and the sliders # note that the order of components is important - graph_components = [self.channel_slider_div, - self.graph_div, - self.time_slider_div, - self.mne_annots] + graph_components = [ + self.channel_slider_div, + self.graph_div, + self.time_slider_div, + self.mne_annots, + ] # pass the list of components into an html.Div - self.container_plot = html.Div(id=self.dash_ids['container-plot'], - className=CSS['timeseries-container'], - children=graph_components) + self.container_plot = html.Div( + id=self.dash_ids["container-plot"], + className=CSS["timeseries-container"], + children=graph_components, + ) class ICVisualizer(MNEVisualizer): @@ -521,7 +559,7 @@ def __init__(self, raw, *args, cmap=None, ic_types=None, **kwargs): an instance of ICVisualizer. Notes - ---- + ----- Any arguments that can be passed to MNEVisualizer can also be passed to ICVisualizer. """ @@ -548,14 +586,11 @@ def load_recording(self, raw, cmap=None, ic_types=None): super(ICVisualizer, self).load_recording(raw) - def update_layout(self, - ch_slider_val=None, - time_slider_val=None): + def update_layout(self, ch_slider_val=None, time_slider_val=None): """Update raw timeseries layout.""" if not self.inst: return - super(ICVisualizer, self).update_layout(ch_slider_val, - time_slider_val) + super(ICVisualizer, self).update_layout(ch_slider_val, time_slider_val) # Update selected channels first_sel_ch = self._ch_slider_val - self.n_sel_ch + 1 @@ -566,16 +601,18 @@ def update_layout(self, self.ic_types ch_names = self.inst.ch_names[::-1][first_sel_ch:last_sel_ch] for ch_name, trace in zip(ch_names, self.traces): - if ch_name in self.inst.info['bads']: - trace.line.color = '#d3d3d3' + if ch_name in self.inst.info["bads"]: + trace.line.color = "#d3d3d3" else: trace.line.color = self.cmap[ch_name] # IC Hover template will show IC number and Time by default - trace.hovertemplate = (f'Component: {ch_name}' + - '
Time: %{x}s
' + - '') + trace.hovertemplate = ( + f"Component: {ch_name}" + + "
Time: %{x}s
" + + "" + ) if self.ic_types: # update hovertemplate with IC label label = self.ic_types[ch_name] - trace.hovertemplate += f'Label: {label}
' - self.graph.figure['data'] = self.traces + trace.hovertemplate += f"Label: {label}
" + self.graph.figure["data"] = self.traces diff --git a/pylossless/dash/pylossless_qc.py b/pylossless/dash/pylossless_qc.py index c0be74f..378381e 100644 --- a/pylossless/dash/pylossless_qc.py +++ b/pylossless/dash/pylossless_qc.py @@ -7,28 +7,29 @@ import argparse from pylossless.dash.app import get_app -desc = 'Launch QCR dashboard with optional directory and filename arguments.' +desc = "Launch QCR dashboard with optional directory and filename arguments." def launch_dash_app(directory=None, filepath=None, disable_buttons=False): """Launch dashboard.""" - app = get_app(fpath=filepath, project_root=directory, - disable_buttons=disable_buttons) + app = get_app( + fpath=filepath, project_root=directory, disable_buttons=disable_buttons + ) app.run_server(debug=True) def main(): """Parse arguments for CLI.""" - disable_button_help = ('If included, Folder and Save buttons are' - ' deactivated') + disable_button_help = "If included, Folder and Save buttons are" " deactivated" parser = argparse.ArgumentParser(description=desc) - parser.add_argument('--directory', help='path to the project folder') - parser.add_argument('--filepath', help='path to the EDF file to load') - parser.add_argument('--disable_buttons', action='store_true', - help=disable_button_help) + parser.add_argument("--directory", help="path to the project folder") + parser.add_argument("--filepath", help="path to the EDF file to load") + parser.add_argument( + "--disable_buttons", action="store_true", help=disable_button_help + ) args = parser.parse_args() launch_dash_app(args.directory, args.filepath, args.disable_buttons) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pylossless/dash/qcannotations.py b/pylossless/dash/qcannotations.py index 04bf002..ebdcf74 100644 --- a/pylossless/dash/qcannotations.py +++ b/pylossless/dash/qcannotations.py @@ -36,25 +36,29 @@ def __init__(self, onset, duration, description_str, layout): self._description = description_str self._dash_layout = layout - self._dash_shape = dict(name=self.id, - type="rect", - xref="x", - yref="y", - x0=self.onset, - y0=self._dash_layout.yaxis['range'][0], - x1=self.onset + self.duration, - y1=self._dash_layout.yaxis['range'][1], - fillcolor='red', - opacity=0.25 if self.duration else .75, - line_width=1, - line_color='black', - layer="below" if self.duration else 'above') - self._dash_description = dict(x=self.onset + self.duration / 2, - y=self._dash_layout.yaxis['range'][1], - text=self.description, - showarrow=False, - yshift=10, - font={'color': '#F1F1F1'}) + self._dash_shape = dict( + name=self.id, + type="rect", + xref="x", + yref="y", + x0=self.onset, + y0=self._dash_layout.yaxis["range"][0], + x1=self.onset + self.duration, + y1=self._dash_layout.yaxis["range"][1], + fillcolor="red", + opacity=0.25 if self.duration else 0.75, + line_width=1, + line_color="black", + layer="below" if self.duration else "above", + ) + self._dash_description = dict( + x=self.onset + self.duration / 2, + y=self._dash_layout.yaxis["range"][1], + text=self.description, + showarrow=False, + yshift=10, + font={"color": "#F1F1F1"}, + ) def update_dash_objects(self): """Update plotly shape/annotations. @@ -68,7 +72,7 @@ def update_dash_objects(self): """ self._dash_shape["x0"] = self.onset self._dash_shape["x1"] = self.onset + self.duration - self._dash_shape["opacity"] = 0.25 if self.duration else .75 + self._dash_shape["opacity"] = 0.25 if self.duration else 0.75 self._dash_description["x"] = self.onset + self.duration / 2 self._dash_description["text"] = self.description @@ -133,10 +137,9 @@ def description(self, description): @staticmethod def from_mne_annotation(annot, layout): """Create an EEGAnnotation instance from an mne.annotation.""" - return EEGAnnotation(annot["onset"], - annot["duration"], - annot["description"], - layout) + return EEGAnnotation( + annot["onset"], annot["duration"], annot["description"], layout + ) def set_editable(self, editable=True): """Set the editable property of the layout.shape for this instance. @@ -165,16 +168,16 @@ def __init__(self, annotations=None): """ if annotations is not None: if isinstance(annotations, list): - self.annotations = pd.Series({annot.id: annot - for annot in annotations}) + self.annotations = pd.Series({annot.id: annot for annot in annotations}) else: self.annotations = pd.Series(annotations) else: self.annotations = pd.Series() def __get_series(self, attr): - return pd.Series({annot.id: getattr(annot, attr) - for annot in self.annotations.values}) + return pd.Series( + {annot.id: getattr(annot, attr) for annot in self.annotations.values} + ) @property def durations(self): @@ -228,10 +231,11 @@ def pick(self, tmin=0, tmax=np.inf): annot_tmin = self.onsets annot_tmax = annot_tmin + self.durations - mask = (((tmin <= annot_tmin) & (annot_tmin < tmax)) | - ((tmin < annot_tmax) & (annot_tmax <= tmax)) | - ((annot_tmin < tmin) & (annot_tmax > tmax)) - ) + mask = ( + ((tmin <= annot_tmin) & (annot_tmin < tmax)) + | ((tmin < annot_tmax) & (annot_tmax <= tmax)) + | ((annot_tmin < tmin) & (annot_tmax > tmax)) + ) return EEGAnnotationList(self.annotations[mask]) def remove(self, id_): @@ -241,8 +245,10 @@ def remove(self, id_): @staticmethod def from_mne_inst(inst, layout): """Create EEGAnnotationList object from a list of mne.annotations.""" - annots = [EEGAnnotation.from_mne_annotation(annot, layout) - for annot in inst.annotations] + annots = [ + EEGAnnotation.from_mne_annotation(annot, layout) + for annot in inst.annotations + ] return EEGAnnotationList(annots) def __len__(self): diff --git a/pylossless/dash/qcgui.py b/pylossless/dash/qcgui.py index 9e0465d..3abc29e 100644 --- a/pylossless/dash/qcgui.py +++ b/pylossless/dash/qcgui.py @@ -45,10 +45,9 @@ def open_folder_dialog(): class QCGUI: """Class that stores the visualizer-plots that are used in the qcr app.""" - def __init__(self, app, - fpath=None, project_root=None, - disable_buttons=False, - verbose=False): + def __init__( + self, app, fpath=None, project_root=None, disable_buttons=False, verbose=False + ): """Initialize class. Parameters @@ -69,7 +68,7 @@ def __init__(self, app, # TODO: Fix this pathing indexing, can likely cause errors. if project_root is None: project_root = Path(__file__).parent.parent - project_root = project_root / 'assets' / 'test_data' + project_root = project_root / "assets" / "test_data" self.project_root = Path(project_root) self.pipeline = LosslessPipeline() @@ -92,65 +91,73 @@ def set_visualizers(self): """Create EEG/ICA time-series dcc.graphs and topomap dcc.graphs.""" # Setting time-series and topomap visualizers if self.ic_types: - cmap = {ic: ic_label_cmap[ic_type] - for ic, ic_type in self.ic_types.items()} + cmap = {ic: ic_label_cmap[ic_type] for ic, ic_type in self.ic_types.items()} else: cmap = None # Using the output of the callback being triggered by # a selection of a new file, so that the two callbacks # are executed sequentially. - refresh_inputs = [Input('file-dropdown', 'placeholder')] + refresh_inputs = [Input("file-dropdown", "placeholder")] self.ica_visualizer = ICVisualizer( - self.app, self.raw_ica, - dash_id_suffix='ica', + self.app, + self.raw_ica, + dash_id_suffix="ica", cmap=cmap, ic_types=self.ic_types, refresh_inputs=refresh_inputs, - set_callbacks=False) + set_callbacks=False, + ) self.eeg_visualizer = MNEVisualizer( self.app, self.raw, refresh_inputs=refresh_inputs.copy(), show_time_slider=False, - set_callbacks=False) + set_callbacks=False, + ) - input_ = Input(self.eeg_visualizer.dash_ids['graph'], "relayoutData") + input_ = Input(self.eeg_visualizer.dash_ids["graph"], "relayoutData") self.ica_visualizer.refresh_inputs.append(input_) - input_ = Input(self.ica_visualizer.dash_ids['graph'], "relayoutData") + input_ = Input(self.ica_visualizer.dash_ids["graph"], "relayoutData") self.eeg_visualizer.refresh_inputs.append(input_) self.ica_visualizer.set_callback() self.eeg_visualizer.set_callback() self.ica_visualizer.mne_annots = self.eeg_visualizer.mne_annots - self.ica_visualizer.dash_ids['mne-annotations'] = \ - self.eeg_visualizer.dash_ids['mne-annotations'] + self.ica_visualizer.dash_ids["mne-annotations"] = self.eeg_visualizer.dash_ids[ + "mne-annotations" + ] montage = self.raw.get_montage() if self.raw else None - self.ica_topo = TopoVizICA(self.app, montage, self.ica, self.ic_types, - show_sensors=True, - refresh_inputs=refresh_inputs) + self.ica_topo = TopoVizICA( + self.app, + montage, + self.ica, + self.ic_types, + show_sensors=True, + refresh_inputs=refresh_inputs, + ) - self.ica_visualizer.new_annot_desc = 'bad_manual' - self.eeg_visualizer.new_annot_desc = 'bad_manual' + self.ica_visualizer.new_annot_desc = "bad_manual" + self.eeg_visualizer.new_annot_desc = "bad_manual" self.ica_visualizer.update_layout() def update_bad_ics(self, annotator="manual"): """Add IC name to raw.info['bads'] after selection by user in app.""" - df = self.pipeline.flags['ic'].data_frame + df = self.pipeline.flags["ic"].data_frame manual_labels_df = pd.DataFrame( dict( - component=self.raw_ica.info['bads'], - annotater=[annotator] * len(self.raw_ica.info['bads']), - ic_type=["manual"] * len(self.raw_ica.info['bads']), - confidence=[1.0] * len(self.raw_ica.info['bads']) - ) + component=self.raw_ica.info["bads"], + annotater=[annotator] * len(self.raw_ica.info["bads"]), + ic_type=["manual"] * len(self.raw_ica.info["bads"]), + confidence=[1.0] * len(self.raw_ica.info["bads"]), + ) ) df = pd.concat((df[df.annotator != annotator], manual_labels_df)) - self.pipeline.flags['ic'].data_frame = df + self.pipeline.flags["ic"].data_frame = df def set_layout(self, disable_buttons=False): """Create the app.layout for the app object. @@ -169,58 +176,73 @@ def set_layout(self, disable_buttons=False): # Layout for file control row # derivatives_dir = self.project_root / 'derivatives' files_list = [] - files_list = [{'label': str(file.name), 'value': str(file)} - for file - in sorted(self.project_root.rglob("*.edf"))] - - dropdown_text = f'current folder: {self.project_root.resolve()}' - logo_fpath = '../assets/logo.png' - folder_button = dbc.Button('Folder', id='folder-selector', - color='primary', - outline=True, - className=CSS['button'], - title=dropdown_text, - disabled=disable_buttons) - save_button = dbc.Button('Save', id='save-button', - color='info', - outline=True, - className=CSS['button'], - disabled=disable_buttons) - self.drop_down = dcc.Dropdown(id='file-dropdown', - className=CSS['dropdown'], - placeholder="Select a file", - options=files_list) - control_header_row = dbc.Row([ - dbc.Col([folder_button, save_button], - width={'size': 2}), - dbc.Col([self.drop_down, - html.P(id='dropdown-output')], - width={'size': 6}), - dbc.Col( - html.Img(src=logo_fpath, - height='40px', - className=CSS['logo']), - width={'size': 2, 'offset': 2}), - ], - className=CSS['file-row'], - align='center', - ) + files_list = [ + {"label": str(file.name), "value": str(file)} + for file in sorted(self.project_root.rglob("*.edf")) + ] + + dropdown_text = f"current folder: {self.project_root.resolve()}" + logo_fpath = "../assets/logo.png" + folder_button = dbc.Button( + "Folder", + id="folder-selector", + color="primary", + outline=True, + className=CSS["button"], + title=dropdown_text, + disabled=disable_buttons, + ) + save_button = dbc.Button( + "Save", + id="save-button", + color="info", + outline=True, + className=CSS["button"], + disabled=disable_buttons, + ) + self.drop_down = dcc.Dropdown( + id="file-dropdown", + className=CSS["dropdown"], + placeholder="Select a file", + options=files_list, + ) + control_header_row = dbc.Row( + [ + dbc.Col([folder_button, save_button], width={"size": 2}), + dbc.Col( + [self.drop_down, html.P(id="dropdown-output")], width={"size": 6} + ), + dbc.Col( + html.Img(src=logo_fpath, height="40px", className=CSS["logo"]), + width={"size": 2, "offset": 2}, + ), + ], + className=CSS["file-row"], + align="center", + ) # Layout for EEG/ICA and Topo plots row - timeseries_div = html.Div([self.eeg_visualizer.container_plot, - self.ica_visualizer.container_plot], - id='channel-and-icsources-div', - className=CSS['timeseries-col']) - visualizers_row = dbc.Row([dbc.Col([timeseries_div], width=8), - dbc.Col(self.ica_topo.container_plot, - className=CSS['topo-col'], - width=4)], - style=STYLE['plots-row'], - className=CSS['plots-row'] - ) + timeseries_div = html.Div( + [self.eeg_visualizer.container_plot, self.ica_visualizer.container_plot], + id="channel-and-icsources-div", + className=CSS["timeseries-col"], + ) + visualizers_row = dbc.Row( + [ + dbc.Col([timeseries_div], width=8), + dbc.Col( + self.ica_topo.container_plot, className=CSS["topo-col"], width=4 + ), + ], + style=STYLE["plots-row"], + className=CSS["plots-row"], + ) # Final Layout - qc_app_layout = dbc.Container([control_header_row, visualizers_row], - fluid=True, style=STYLE['qc-container']) + qc_app_layout = dbc.Container( + [control_header_row, visualizers_row], + fluid=True, + style=STYLE["qc-container"], + ) self.app.layout.children.append(qc_app_layout) def load_recording(self, fpath, verbose=False): @@ -240,43 +262,46 @@ def load_recording(self, fpath, verbose=False): self.raw = self.pipeline.raw self.ica = self.pipeline.ica2 if self.raw: - info = mne.create_info(self.ica._ica_names, - sfreq=self.raw.info['sfreq'], - ch_types=['eeg'] * self.ica.n_components_, - verbose=verbose) + info = mne.create_info( + self.ica._ica_names, + sfreq=self.raw.info["sfreq"], + ch_types=["eeg"] * self.ica.n_components_, + verbose=verbose, + ) sources = self.ica.get_sources(self.raw).get_data() self.raw_ica = mne.io.RawArray(sources, info, verbose=verbose) - self.raw_ica.set_meas_date(self.raw.info['meas_date']) + self.raw_ica.set_meas_date(self.raw.info["meas_date"]) self.raw_ica.set_annotations(self.raw.annotations) - df = self.pipeline.flags['ic'].data_frame + df = self.pipeline.flags["ic"].data_frame - bads = [ic_name - for ic_name, ic_type - in df[["component", "ic_type"]].values - if ic_type == "manual"] + bads = [ + ic_name + for ic_name, ic_type in df[["component", "ic_type"]].values + if ic_type == "manual" + ] self.raw_ica.info["bads"] = bads else: self.raw_ica = None - df = self.pipeline.flags['ic'].data_frame + df = self.pipeline.flags["ic"].data_frame self.ic_types = df[df.annotator == "ic_label"] - self.ic_types = self.ic_types.set_index('component')['ic_type'] + self.ic_types = self.ic_types.set_index("component")["ic_type"] self.ic_types = self.ic_types.to_dict() - cmap = {ic: ic_label_cmap[ic_type] - for ic, ic_type in self.ic_types.items()} - self.ica_visualizer.load_recording(self.raw_ica, cmap=cmap, - ic_types=self.ic_types) + cmap = {ic: ic_label_cmap[ic_type] for ic, ic_type in self.ic_types.items()} + self.ica_visualizer.load_recording( + self.raw_ica, cmap=cmap, ic_types=self.ic_types + ) self.eeg_visualizer.load_recording(self.raw) - self.ica_topo.load_recording(self.raw.get_montage(), - self.ica, self.ic_types) + self.ica_topo.load_recording(self.raw.get_montage(), self.ica, self.ic_types) def set_callbacks(self): """Define additional callbacks that will be used by the qcr app.""" + # TODO: delete this folder selection callback @self.app.callback( - Output('file-dropdown', 'options'), - Input('folder-selector', 'n_clicks'), - prevent_initial_call=True + Output("file-dropdown", "options"), + Input("folder-selector", "n_clicks"), + prevent_initial_call=True, ) def folder_button_clicked(n_clicks): if n_clicks: @@ -284,17 +309,18 @@ def folder_button_clicked(n_clicks): folder_path = pool.apply(open_folder_dialog) self.project_root = Path(folder_path) - files_list = [{'label': str(file.name), 'value': str(file)} - for file - in sorted(self.project_root.rglob("*.edf"))] + files_list = [ + {"label": str(file.name), "value": str(file)} + for file in sorted(self.project_root.rglob("*.edf")) + ] return files_list return dash.no_update @self.app.callback( - Output('file-dropdown', 'placeholder'), - Output(self.eeg_visualizer.dash_ids['loading-output'], "children"), - Input('file-dropdown', 'value'), - prevent_initial_call=False + Output("file-dropdown", "placeholder"), + Output(self.eeg_visualizer.dash_ids["loading-output"], "children"), + Input("file-dropdown", "value"), + prevent_initial_call=False, ) def file_selected(value): if value: # on selection of dropdown item @@ -304,33 +330,32 @@ def file_selected(value): if self.fpath: self.load_recording(self.fpath) return str(self.fpath.name), [] - return '', [] + return "", [] @self.app.callback( - Output('dropdown-output', 'children'), - Input('save-button', 'n_clicks'), - prevent_initial_call=True + Output("dropdown-output", "children"), + Input("save-button", "n_clicks"), + prevent_initial_call=True, ) def save_file(n_clicks): self.update_bad_ics() self.eeg_visualizer.update_inst_annnotations() - self.pipeline.save(get_bids_path_from_fname(self.fpath), - overwrite=True) - logger.info('file saved!') + self.pipeline.save(get_bids_path_from_fname(self.fpath), overwrite=True) + logger.info("file saved!") return dash.no_update properties = ["value", "min", "max", "marks"] - slider_ids = [self.eeg_visualizer.dash_ids['time-slider'], - self.ica_visualizer.dash_ids['time-slider']] - sliders = [self.ica_visualizer.time_slider, - self.eeg_visualizer.time_slider] + slider_ids = [ + self.eeg_visualizer.dash_ids["time-slider"], + self.ica_visualizer.dash_ids["time-slider"], + ] + sliders = [self.ica_visualizer.time_slider, self.eeg_visualizer.time_slider] decorator_args = [] for slider_id in slider_ids: - decorator_args += [Output(slider_id, property) - for property in properties] + decorator_args += [Output(slider_id, property) for property in properties] for slider_id in slider_ids: - decorator_args += [Input(slider_id, 'value')] - decorator_args += [Input('file-dropdown', 'placeholder')] + decorator_args += [Input(slider_id, "value")] + decorator_args += [Input("file-dropdown", "placeholder")] @self.app.callback(*decorator_args, prevent_initial_call=True) def sync_time_sliders(*args): @@ -346,6 +371,7 @@ def sync_time_sliders(*args): The file-dropdown dash component. The placeholder component property of this dash-component is used to refresh the time sliders when loading a new file. + Returns ------- For the following component_properties of the EEG time slider and @@ -371,54 +397,61 @@ def sync_time_sliders(*args): if ctx.triggered: trigger_id = ctx.triggered[0]["prop_id"].split(".")[0] if trigger_id == slider_ids[0]: - return ([no_update]*len(properties) + [args[0]] + - [no_update]*(len(properties)-1)) + return ( + [no_update] * len(properties) + + [args[0]] + + [no_update] * (len(properties) - 1) + ) if trigger_id == slider_ids[1]: - return [args[1]] + [no_update]*(len(properties)*2-1) - if trigger_id == 'file-dropdown': + return [args[1]] + [no_update] * (len(properties) * 2 - 1) + if trigger_id == "file-dropdown": args = [] for slider in sliders[::-1]: - args += [getattr(slider, property) - for property in properties] + args += [getattr(slider, property) for property in properties] return args @self.app.callback( - Output(self.ica_visualizer.dash_ids['ch-slider'], - 'value'), - Output('topo-slider', 'value'), - Output(self.ica_visualizer.dash_ids['ch-slider'], - component_property='min'), - Output(self.ica_visualizer.dash_ids['ch-slider'], - component_property='max'), - Output('topo-slider', - component_property='min'), - Output('topo-slider', - component_property='max'), - Input(self.ica_visualizer.dash_ids['ch-slider'], - 'value'), - Input('topo-slider', 'value'), - Input('file-dropdown', 'placeholder'), - prevent_initial_call=True) + Output(self.ica_visualizer.dash_ids["ch-slider"], "value"), + Output("topo-slider", "value"), + Output(self.ica_visualizer.dash_ids["ch-slider"], component_property="min"), + Output(self.ica_visualizer.dash_ids["ch-slider"], component_property="max"), + Output("topo-slider", component_property="min"), + Output("topo-slider", component_property="max"), + Input(self.ica_visualizer.dash_ids["ch-slider"], "value"), + Input("topo-slider", "value"), + Input("file-dropdown", "placeholder"), + prevent_initial_call=True, + ) def sync_ica_sliders(ica_ch_slider, ica_topo_slider, selected_file): """Sync ICA-Raw and ICA-topo sliders and refresh upon new file.""" ctx = dash.callback_context if ctx.triggered: trigger_id = ctx.triggered[0]["prop_id"].split(".")[0] - if trigger_id == self.ica_visualizer.dash_ids['ch-slider']: + if trigger_id == self.ica_visualizer.dash_ids["ch-slider"]: # If user dragged the ica-raw ch-slider value = ica_ch_slider # only update the ica-topo slider val. - return (no_update, value, - no_update, no_update, # min max 4 ica-raw slider. - no_update, no_update) # min max 4 topo-slider - if trigger_id == 'topo-slider': + return ( + no_update, + value, + no_update, + no_update, # min max 4 ica-raw slider. + no_update, + no_update, + ) # min max 4 topo-slider + if trigger_id == "topo-slider": # If the user dragged the topoplot slider value = ica_topo_slider # only update the ica-raw ch-slider val - return (value, no_update, - no_update, no_update, # min max 4 ica-raw slider - no_update, no_update) # min max 4 topo-slider - if trigger_id == 'file-dropdown': + return ( + value, + no_update, + no_update, + no_update, # min max 4 ica-raw slider + no_update, + no_update, + ) # min max 4 topo-slider + if trigger_id == "file-dropdown": # If the user selected a new file value = self.ica_visualizer.channel_slider.value min_ = self.ica_visualizer.channel_slider.min @@ -428,20 +461,18 @@ def sync_ica_sliders(ica_ch_slider, ica_topo_slider, selected_file): return value, value, min_, max_, min_, max_ @self.app.callback( - Output(self.eeg_visualizer.dash_ids['ch-slider'], - 'value'), - Output(self.eeg_visualizer.dash_ids['ch-slider'], - component_property='min'), - Output(self.eeg_visualizer.dash_ids['ch-slider'], - component_property='max'), - Input('file-dropdown', 'placeholder'), - prevent_initial_call=True) + Output(self.eeg_visualizer.dash_ids["ch-slider"], "value"), + Output(self.eeg_visualizer.dash_ids["ch-slider"], component_property="min"), + Output(self.eeg_visualizer.dash_ids["ch-slider"], component_property="max"), + Input("file-dropdown", "placeholder"), + prevent_initial_call=True, + ) def refresh_eeg_ch_slider(selected_file): """Refresh eeg graph ch-slider upon new file selection.""" ctx = dash.callback_context if ctx.triggered: trigger_id = ctx.triggered[0]["prop_id"].split(".")[0] - if trigger_id == 'file-dropdown': + if trigger_id == "file-dropdown": # If the user selected a new file value = self.eeg_visualizer.channel_slider.value min_ = self.eeg_visualizer.channel_slider.min diff --git a/pylossless/dash/tests/conftest.py b/pylossless/dash/tests/conftest.py index 1ee0a97..282494a 100644 --- a/pylossless/dash/tests/conftest.py +++ b/pylossless/dash/tests/conftest.py @@ -13,6 +13,6 @@ def pytest_setup_options(): """Configure the dash tests with the chrom webdrive for CI.""" options = Options() - options.add_argument('--headless') - options.add_argument('--disable-gpu') + options.add_argument("--headless") + options.add_argument("--disable-gpu") return options diff --git a/pylossless/dash/tests/test_topo_viz.py b/pylossless/dash/tests/test_topo_viz.py index 4f96f6e..7cce43b 100644 --- a/pylossless/dash/tests/test_topo_viz.py +++ b/pylossless/dash/tests/test_topo_viz.py @@ -2,7 +2,7 @@ # Scott Huberty # License: MIT -"""Tests for topo_viz.py""" +"""Tests for topo_viz.py.""" import mne from dash import html @@ -11,6 +11,7 @@ def get_raw_ica(): + """Get raw and ICA object.""" data_path = mne.datasets.sample.data_path() raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif" raw = mne.io.read_raw_fif(raw_fname) @@ -32,12 +33,14 @@ def get_raw_ica(): def test_TopoPlot(): + """Test plotting topoplots with plotly.""" raw, ica = get_raw_ica() data = dict(zip(ica.ch_names, ica.get_components()[:, 0])) TopoPlot(raw.get_montage(), data, res=200).figure def test_GridTopoPlot(): + """Test plotting grid of topoplots with plotly.""" raw, ica = get_raw_ica() topo_data = TopoData() @@ -46,28 +49,33 @@ def test_GridTopoPlot(): offset = 2 nb_topo = 4 - plot_data = topo_data.topo_values.iloc[::-1].iloc[offset:offset+nb_topo] + plot_data = topo_data.topo_values.iloc[::-1].iloc[offset : offset + nb_topo] plot_data = list(plot_data.T.to_dict().values()) - GridTopoPlot(2, 2, raw.get_montage(), plot_data, - res=200, width=300, height=300, - subplots_kwargs=dict(subplot_titles=[1, 2, 3, 4], - vertical_spacing=0.05)).figure + GridTopoPlot( + 2, + 2, + raw.get_montage(), + plot_data, + res=200, + width=300, + height=300, + subplots_kwargs=dict(subplot_titles=[1, 2, 3, 4], vertical_spacing=0.05), + ).figure # chromedriver: https://chromedriver.storage.googleapis.com/ # index.html?path=114.0.5735.90/ def test_TopoViz(dash_duo): + """Test TopoViz.""" raw, ica = get_raw_ica() topo_data = TopoData() for comp in ica.get_components().T: topo_data.add_topomap(dict(zip(ica.ch_names, comp))) - topo_viz = TopoViz(data=topo_data, montage=raw.get_montage(), - mode="standalone") + topo_viz = TopoViz(data=topo_data, montage=raw.get_montage(), mode="standalone") - topo_viz.app.layout.children.append(html.Div(id="nully-wrapper", - children=0)) + topo_viz.app.layout.children.append(html.Div(id="nully-wrapper", children=0)) dash_duo.start_server(topo_viz.app) assert dash_duo.find_element("#nully-wrapper").text == "0" diff --git a/pylossless/dash/topo_viz.py b/pylossless/dash/topo_viz.py index d89eb4a..50b4991 100644 --- a/pylossless/dash/topo_viz.py +++ b/pylossless/dash/topo_viz.py @@ -28,7 +28,7 @@ from . import ic_label_cmap # thin lines in the background and numbers below -axis = {'showgrid': False, 'visible': False} +axis = {"showgrid": False, "visible": False} yaxis = copy(axis) yaxis.update({"scaleanchor": "x", "scaleratio": 1}) @@ -36,10 +36,21 @@ class TopoPlot: # TODO: Fix/finish doc comments for this class. """Representation of a classic EEG topographic map as a plotly figure.""" - def __init__(self, montage="standard_1020", data=None, figure=None, - color="black", row=None, col=None, res=64, width=None, - height=None, cmap='RdBu_r', show_sensors=True, - colorbar=False): + def __init__( + self, + montage="standard_1020", + data=None, + figure=None, + color="black", + row=None, + col=None, + res=64, + width=None, + height=None, + cmap="RdBu_r", + show_sensors=True, + colorbar=False, + ): """Initialize instance. Parameters @@ -151,8 +162,9 @@ def set_data(self, data): self.info = create_info(names, sfreq=256, ch_types="eeg") with warnings.catch_warnings(): warnings.simplefilter("ignore") - RawArray(np.zeros((len(names), 1)), self.info, copy=None, - verbose=False).set_montage(self.montage) + RawArray( + np.zeros((len(names), 1)), self.info, copy=None, verbose=False + ).set_montage(self.montage) self.set_head_pos_contours() # TODO: Finish/fix docstring @@ -161,29 +173,33 @@ def set_head_pos_contours(self, sphere=None, picks=None): if not self.info: return sphere = _check_sphere(sphere, self.info) - self.pos, self.outlines = _get_pos_outlines(self.info, picks, sphere, - to_sphere=True) + self.pos, self.outlines = _get_pos_outlines( + self.info, picks, sphere, to_sphere=True + ) # TODO: Finish/fix docstring def get_head_scatters(self, color="back", show_sensors=True): """Build scatter plot from head position data.""" - outlines_scat = [go.Scatter(x=x, y=y, line=dict(color=color), - mode='lines', showlegend=False) - for key, (x, y) in self.outlines.items() - if 'clip' not in key and "mask" not in key] + outlines_scat = [ + go.Scatter(x=x, y=y, line=dict(color=color), mode="lines", showlegend=False) + for key, (x, y) in self.outlines.items() + if "clip" not in key and "mask" not in key + ] if show_sensors: - pos_scat = go.Scatter(x=self.pos.T[0], y=self.pos.T[1], - line=dict(color=color), mode='markers', - marker=dict(color=color, - size=2, - opacity=.5), - showlegend=False) + pos_scat = go.Scatter( + x=self.pos.T[0], + y=self.pos.T[1], + line=dict(color=color), + mode="markers", + marker=dict(color=color, size=2, opacity=0.5), + showlegend=False, + ) return outlines_scat + [pos_scat] return outlines_scat - def get_heatmap_data(self, ch_type="eeg", extrapolate='auto'): + def get_heatmap_data(self, ch_type="eeg", extrapolate="auto"): """Get the data to use for the topo plots. Parameters @@ -203,42 +219,52 @@ def get_heatmap_data(self, ch_type="eeg", extrapolate='auto'): """ extrapolate = _check_extrapolate(extrapolate, ch_type) # find mask limits and setup interpolation - _, Xi, Yi, interp = _setup_interp(self.pos, res=self.res, - image_interp="cubic", - extrapolate=extrapolate, - outlines=self.outlines, - border='mean') + _, Xi, Yi, interp = _setup_interp( + self.pos, + res=self.res, + image_interp="cubic", + extrapolate=extrapolate, + outlines=self.outlines, + border="mean", + ) interp.set_values(np.array(list(self.__data.values()))) Zi = interp.set_locations(Xi, Yi)() # Clip to the outer circler x0, y0 = self.outlines["clip_origin"] x_rad, y_rad = self.outlines["clip_radius"] - Zi[np.sqrt(((Xi - x0)/x_rad)**2 + ((Yi-y0)/y_rad)**2) > 1] = np.nan + Zi[np.sqrt(((Xi - x0) / x_rad) ** 2 + ((Yi - y0) / y_rad) ** 2) > 1] = np.nan return {"x": Xi[0], "y": Yi[:, 0], "z": Zi} def _update_axes(self): - self.figure.update_xaxes({'showgrid': False, 'visible': False}, - row=self.row, col=self.col) + self.figure.update_xaxes( + {"showgrid": False, "visible": False}, row=self.row, col=self.col + ) - scale_anchor = list(self.figure.select_yaxes(row=self.row, - col=self.col)) + scale_anchor = list(self.figure.select_yaxes(row=self.row, col=self.col)) scale_anchor = scale_anchor[0]["anchor"] if not scale_anchor: scale_anchor = "x" - self.figure.update_yaxes({'showgrid': False, 'visible': False, - "scaleanchor": scale_anchor, - "scaleratio": 1}, - row=self.row, col=self.col) + self.figure.update_yaxes( + { + "showgrid": False, + "visible": False, + "scaleanchor": scale_anchor, + "scaleratio": 1, + }, + row=self.row, + col=self.col, + ) self.figure.update_layout( - autosize=False, - width=self.width, - height=self.height, - plot_bgcolor='rgba(0,0,0,0)', - paper_bgcolor='rgba(0,0,0,0)', - margin=dict(l=0, r=0, b=0, t=20)) + autosize=False, + width=self.width, + height=self.height, + plot_bgcolor="rgba(0,0,0,0)", + paper_bgcolor="rgba(0,0,0,0)", + margin=dict(l=0, r=0, b=0, t=20), + ) def plot_topo(self, **kwargs): """Plot the topomap. @@ -255,9 +281,11 @@ def plot_topo(self, **kwargs): if self.__data is None: return - heatmap_trace = go.Heatmap(showscale=self.colorbar, - colorscale=self.cmap, - **self.get_heatmap_data(**kwargs)) + heatmap_trace = go.Heatmap( + showscale=self.colorbar, + colorscale=self.cmap, + **self.get_heatmap_data(**kwargs) + ) for trace in self.get_head_scatters(color=self.color): self.figure.add_trace(trace, row=self.row, col=self.col) @@ -270,24 +298,33 @@ def plot_topo(self, **kwargs): def __check_shape__(rows, cols, data, fill=None): if not isinstance(data, (list, tuple, np.ndarray)): - return np.array([[data]*cols]*rows) + return np.array([[data] * cols] * rows) data = np.array(data) if data.shape == (rows, cols): return data - if len(data.ravel()) < rows*cols: - data = np.concatenate((data.ravel(), - [fill]*(rows*cols-len(data.ravel())))) + if len(data.ravel()) < rows * cols: + data = np.concatenate( + (data.ravel(), [fill] * (rows * cols - len(data.ravel()))) + ) return data.reshape((rows, cols)) class GridTopoPlot: """Representation of grid of topomaps as a plotly figure.""" - def __init__(self, rows=1, cols=1, montage="standard_1020", - data=None, figure=None, color="black", - subplots_kwargs=None, **kwargs): + def __init__( + self, + rows=1, + cols=1, + montage="standard_1020", + data=None, + figure=None, + color="black", + subplots_kwargs=None, + **kwargs + ): """Initialize instance. Parameters @@ -321,8 +358,7 @@ def __init__(self, rows=1, cols=1, montage="standard_1020", montage = __check_shape__(rows, cols, montage) color = __check_shape__(rows, cols, color) - subplots_kwargs_ = dict(horizontal_spacing=0.03, - vertical_spacing=0.03) + subplots_kwargs_ = dict(horizontal_spacing=0.03, vertical_spacing=0.03) if subplots_kwargs: subplots_kwargs_.update(subplots_kwargs) self.rows = rows @@ -336,25 +372,36 @@ def __init__(self, rows=1, cols=1, montage="standard_1020", self.__data = __check_shape__(rows, cols, data) if figure is None: - self.figure = make_subplots(rows=rows, cols=cols, - **subplots_kwargs_) + self.figure = make_subplots(rows=rows, cols=cols, **subplots_kwargs_) else: self.figure = figure - self.topos = np.array([[TopoPlot(montage=m, data=d, - figure=self.figure, col=col+1, - row=row+1, color=color, **kwargs) - for col, (m, d, color) - in enumerate(zip(montage_row, data_row, - color_row))] - for row, (montage_row, data_row, color_row) - in enumerate(zip(montage, self.__data, - self.color))]) + self.topos = np.array( + [ + [ + TopoPlot( + montage=m, + data=d, + figure=self.figure, + col=col + 1, + row=row + 1, + color=color, + **kwargs + ) + for col, (m, d, color) in enumerate( + zip(montage_row, data_row, color_row) + ) + ] + for row, (montage_row, data_row, color_row) in enumerate( + zip(montage, self.__data, self.color) + ) + ] + ) @property def nb_topo(self): """Return the number of topoplots.""" - return self.rows*self.cols + return self.rows * self.cols class TopoData: # TODO: Fix/finish doc comments for this class. @@ -368,8 +415,9 @@ def add_topomap(self, topomap: dict, title=None): """topomap: dict.""" if not title: title = str(len(self.topo_values)) - self.topo_values = pd.concat([self.topo_values, - pd.DataFrame(topomap, index=[title])]) + self.topo_values = pd.concat( + [self.topo_values, pd.DataFrame(topomap, index=[title])] + ) @property def nb_topo(self): @@ -380,11 +428,25 @@ def nb_topo(self): class TopoViz: # TODO: Fix/finish doc comments for this class. """Representation of a classic EEG topographic map.""" - def __init__(self, app=None, montage=None, data=None, rows=5, cols=4, - width=400, height=600, margin_x=4/5, margin_y=2/5, res=64, - head_contours_color="black", - cmap='RdBu_r', show_sensors=True, mode=None, - show_slider=True, refresh_inputs=None): + def __init__( + self, + app=None, + montage=None, + data=None, + rows=5, + cols=4, + width=400, + height=600, + margin_x=4 / 5, + margin_y=2 / 5, + res=64, + head_contours_color="black", + cmap="RdBu_r", + show_sensors=True, + mode=None, + show_slider=True, + refresh_inputs=None, + ): """Initialize instance. Parameters @@ -470,24 +532,24 @@ def __init__(self, app=None, montage=None, data=None, rows=5, cols=4, stylesheets = [dbc.themes.SLATE] if mode == "standalone_jupyter": from jupyter_dash import JupyterDash - self.app = JupyterDash("TopoViz", - external_stylesheets=stylesheets) + + self.app = JupyterDash("TopoViz", external_stylesheets=stylesheets) self.mode = mode else: - self.app = dash.Dash("TopoViz", - external_stylesheets=stylesheets) + self.app = dash.Dash("TopoViz", external_stylesheets=stylesheets) self.mode = "standalone" self.app.layout = html.Div([]) else: self.app = app self.mode = "embedded" - self.graph = dcc.Graph(figure=None, id='topo-graph', - className=CSS['topo-dcc']) - self.graph_div = html.Div(children=[self.graph], - id='topo-graph-div', - className=CSS['topo-dcc-div'], - style=STYLE['topo-dcc-div']) + self.graph = dcc.Graph(figure=None, id="topo-graph", className=CSS["topo-dcc"]) + self.graph_div = html.Div( + children=[self.graph], + id="topo-graph-div", + className=CSS["topo-dcc-div"], + style=STYLE["topo-dcc-div"], + ) self._init_slider() self.set_data(montage, data, head_contours_color) @@ -541,8 +603,9 @@ def set_data(self, montage=None, data=None, head_contours_color="black"): if montage is not None: self.montage = montage if isinstance(head_contours_color, str): - head_contours_color = {title: head_contours_color - for title in self.data.topo_values.index} + head_contours_color = { + title: head_contours_color for title in self.data.topo_values.index + } if head_contours_color: self.head_contours_color = head_contours_color @@ -566,35 +629,40 @@ def initialize_layout(self, slider_val=None, show_sensors=True): titles = self.data.topo_values.index - last_sel_topo = self.offset+self.nb_sel_topo - titles = titles[::-1][self.offset:last_sel_topo][::-1] + last_sel_topo = self.offset + self.nb_sel_topo + titles = titles[::-1][self.offset : last_sel_topo][::-1] colors = [self.head_contours_color[title] for title in titles] # The indexing with ch_names is to ensure the order # of the channels are compatible between plot_data and the montage - ch_names = [ch_name for ch_name in self.montage.ch_names - if ch_name in self.data.topo_values.columns] + ch_names = [ + ch_name + for ch_name in self.montage.ch_names + if ch_name in self.data.topo_values.columns + ] plot_data = self.data.topo_values.loc[titles, ch_names] plot_data = list(plot_data.T.to_dict().values()) if len(plot_data) < self.nb_sel_topo: - nb_missing_topo = self.nb_sel_topo-len(plot_data) - plot_data = np.concatenate((plot_data, - [None]*nb_missing_topo)) - - self.figure = GridTopoPlot(rows=self.rows, cols=self.cols, - montage=self.montage, data=plot_data, - color=colors, - res=self.res, - height=self.height, - width=self.width, - show_sensors=show_sensors, - subplots_kwargs=dict( - horizontal_spacing=0.03, - vertical_spacing=0.03, - subplot_titles=titles, - ) - ).figure + nb_missing_topo = self.nb_sel_topo - len(plot_data) + plot_data = np.concatenate((plot_data, [None] * nb_missing_topo)) + + self.figure = GridTopoPlot( + rows=self.rows, + cols=self.cols, + montage=self.montage, + data=plot_data, + color=colors, + res=self.res, + height=self.height, + width=self.width, + show_sensors=show_sensors, + subplots_kwargs=dict( + horizontal_spacing=0.03, + vertical_spacing=0.03, + subplot_titles=titles, + ), + ).figure @property def nb_sel_topo(self): @@ -616,53 +684,59 @@ def nb_topo(self): def _init_slider(self): """Initialize the dcc.Slider component for the topoplots.""" - self.topo_slider = dcc.Slider(id='topo-slider', - min=self.nb_sel_topo - 1, - max=self.nb_topo - 1, - step=1, - marks=None, - value=self.nb_topo - 1, - included=False, - updatemode='mouseup', - vertical=True, - verticalHeight=400) - self.topo_slider_div = html.Div(self.topo_slider, - className=CSS['topo-slider-div'], - style={}) + self.topo_slider = dcc.Slider( + id="topo-slider", + min=self.nb_sel_topo - 1, + max=self.nb_topo - 1, + step=1, + marks=None, + value=self.nb_topo - 1, + included=False, + updatemode="mouseup", + vertical=True, + verticalHeight=400, + ) + self.topo_slider_div = html.Div( + self.topo_slider, className=CSS["topo-slider-div"], style={} + ) if not self.show_slider: - self.topo_slider_div.style.update({'display': 'none'}) + self.topo_slider_div.style.update({"display": "none"}) def _set_div(self): """Set the html.Div component for the topoplots.""" # outer_div includes slider obj graph_components = [self.topo_slider_div, self.graph_div] - self.container_plot = html.Div(children=graph_components, - id="ica-topo-div", - className=CSS['topo-container'], - style={'display': 'none'}) + self.container_plot = html.Div( + children=graph_components, + id="ica-topo-div", + className=CSS["topo-container"], + style={"display": "none"}, + ) def _set_callback(self): """Create the callback for the dcc.graph component of the topoplots.""" - args = [Output('topo-graph', 'figure')] - args += [Input('topo-slider', 'value')] + args = [Output("topo-graph", "figure")] + args += [Input("topo-slider", "value")] if self.refresh_inputs: args += self.refresh_inputs @self.app.callback(*args, suppress_callback_exceptions=False) def callback(slider_val, *args): - self.initialize_layout(slider_val=slider_val, - show_sensors=self.show_sensors) + self.initialize_layout( + slider_val=slider_val, show_sensors=self.show_sensors + ) if self.figure: return self.figure return dash.no_update - @self.app.callback(Output('ica-topo-div', 'style'), - Input('topo-graph', 'figure'), - ) + @self.app.callback( + Output("ica-topo-div", "style"), + Input("topo-graph", "figure"), + ) def show_figure(figure): if figure: - return {'display': 'block'} - return {'display': 'none'} + return {"display": "block"} + return {"display": "none"} class TopoVizICA(TopoViz): @@ -732,12 +806,16 @@ def init_vars(self, montage, ica, ic_labels): if not montage or not ica: return None - data = TopoData([dict(zip(montage.ch_names, component)) - for component in ica.get_components().T]) + data = TopoData( + [ + dict(zip(montage.ch_names, component)) + for component in ica.get_components().T + ] + ) if ic_labels: - self.head_contours_color = {comp: ic_label_cmap[label] - for comp, label - in ic_labels.items()} + self.head_contours_color = { + comp: ic_label_cmap[label] for comp, label in ic_labels.items() + } data.topo_values.index = list(ic_labels.keys()) return data diff --git a/pylossless/flagging.py b/pylossless/flagging.py index 12f1e23..cb364a0 100644 --- a/pylossless/flagging.py +++ b/pylossless/flagging.py @@ -67,7 +67,7 @@ def add_flag_cat(self, kind, bad_ch_names, *args): ------- None """ - logger.debug(f'NEW BAD CHANNELS {bad_ch_names}') + logger.debug(f"NEW BAD CHANNELS {bad_ch_names}") if isinstance(bad_ch_names, xr.DataArray): bad_ch_names = bad_ch_names.values self[kind] = bad_ch_names @@ -92,13 +92,11 @@ def rereference(self, inst, **kwargs): :meth:`mne.io.Raw.set_eeg_reference` method. """ # Concatenate and remove duplicates - bad_chs = list(set(self.ll.find_outlier_chs(inst) + - self.get_flagged() + - inst.info['bads'])) - ref_chans = [ch for ch in inst.copy().pick("eeg").ch_names - if ch not in bad_chs] - inst.set_eeg_reference(ref_channels=ref_chans, - **kwargs) + bad_chs = list( + set(self.ll.find_outlier_chs(inst) + self.get_flagged() + inst.info["bads"]) + ) + ref_chans = [ch for ch in inst.copy().pick("eeg").ch_names if ch not in bad_chs] + inst.set_eeg_reference(ref_channels=ref_chans, **kwargs) def save_tsv(self, fname): """Save flagged channel annotations to a text file. @@ -111,11 +109,11 @@ def save_tsv(self, fname): labels = [] ch_names = [] for key in self: - labels.extend([key]*len(self[key])) + labels.extend([key] * len(self[key])) ch_names.extend(self[key]) - pd.DataFrame({"labels": labels, - "ch_names": ch_names}).to_csv(fname, - index=False, sep="\t") + pd.DataFrame({"labels": labels, "ch_names": ch_names}).to_csv( + fname, index=False, sep="\t" + ) def load_tsv(self, fname): """Load serialized channel annotations. @@ -126,7 +124,7 @@ def load_tsv(self, fname): Filename of the tsv file with the annotation information to be loaded. """ - out_df = pd.read_csv(fname, sep='\t') + out_df = pd.read_csv(fname, sep="\t") for label, grp_df in out_df.groupby("labels"): self[label] = grp_df.ch_names.values @@ -187,15 +185,15 @@ def add_flag_cat(self, kind, bad_epoch_inds, epochs): def load_from_raw(self, raw): """Load ``'bad_pylossless'`` annotations from raw object.""" - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] for annot in raw.annotations: - if annot['description'].startswith('bad_pylossless'): - ind_onset = int(np.round(annot['onset'] * sfreq)) - ind_dur = int(np.round(annot['duration'] * sfreq)) + if annot["description"].startswith("bad_pylossless"): + ind_onset = int(np.round(annot["onset"] * sfreq)) + ind_dur = int(np.round(annot["duration"] * sfreq)) inds = np.arange(ind_onset, ind_onset + ind_dur) - if annot['description'] not in self: - self[annot['description']] = list() - self[annot['description']].append(inds) + if annot["description"] not in self: + self[annot["description"]] = list() + self[annot["description"]].append(inds) class FlaggedICs(dict): @@ -269,12 +267,12 @@ def save_tsv(self, fname): The output filename. """ self.fname = fname - self.data_frame.to_csv(fname, sep='\t', index=False, na_rep='n/a') + self.data_frame.to_csv(fname, sep="\t", index=False, na_rep="n/a") # TODO: Add parameters. def load_tsv(self, fname, data_frame=None): """Load flagged ICs from file.""" self.fname = fname if data_frame is None: - data_frame = pd.read_csv(fname, sep='\t') + data_frame = pd.read_csv(fname, sep="\t") self.data_frame = data_frame diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index 507ebca..dc989d2 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -65,10 +65,10 @@ def epochs_to_xr(epochs, kind="ch", ica=None): else: raise ValueError("The argument kind must be equal to 'ch' or 'ic'.") - return xr.DataArray(data, - coords={'epoch': np.arange(data.shape[0]), - kind: names, - "time": epochs.times}) + return xr.DataArray( + data, + coords={"epoch": np.arange(data.shape[0]), kind: names, "time": epochs.times}, + ) def get_operate_dim(array, flag_dim): @@ -125,28 +125,32 @@ def _get_outliers_quantile(array, dim, lower=0.25, upper=0.75, mid=0.5, k=3): Vector of values (of size n_channels or n_epochs) to be considered the upper thresholds for outliers. """ - lower_val, mid_val, upper_val = array.quantile([lower, mid, upper], - dim=dim) + lower_val, mid_val, upper_val = array.quantile([lower, mid, upper], dim=dim) # Code below deviates from Tukeys method (Q2 +/- k(Q3-Q1)) # because we need to account for distribution skewness. lower_dist = mid_val - lower_val upper_dist = upper_val - mid_val - return mid_val - lower_dist*k, mid_val + upper_dist*k + return mid_val - lower_dist * k, mid_val + upper_dist * k def _get_outliers_trimmed(array, dim, trim=0.2, k=3): """Calculate outliers for Epochs or Channels based on the trimmed mean.""" - trim_mean = partial(scipy.stats.mstats.trimmed_mean, - limits=(trim, trim)) + trim_mean = partial(scipy.stats.mstats.trimmed_mean, limits=(trim, trim)) trim_std = partial(scipy.stats.mstats.trimmed_std, limits=(trim, trim)) m_dist = array.reduce(trim_mean, dim=dim) s_dist = array.reduce(trim_std, dim=dim) - return m_dist - s_dist*k, m_dist + s_dist*k + return m_dist - s_dist * k, m_dist + s_dist * k -def _detect_outliers(array, flag_dim='epoch', outlier_method='quantile', - flag_crit=0.2, init_dir='both', outliers_kwargs=None): +def _detect_outliers( + array, + flag_dim="epoch", + outlier_method="quantile", + flag_crit=0.2, + init_dir="both", + outliers_kwargs=None, +): """Mark epochs, channels, or ICs as flagged for artefact. Parameters @@ -173,6 +177,7 @@ def _detect_outliers(array, flag_dim='epoch', outlier_method='quantile', Set in the pipeline config. 'k', 'lower', and 'upper' kwargs can be passed to _get_outliers_quantile. 'k' can also be passed to _get_outliers_trimmed. + Returns ------- boolean xr.DataArray of shape n_epochs, n_times, where an epoch x channel @@ -185,37 +190,37 @@ def _detect_outliers(array, flag_dim='epoch', outlier_method='quantile', # Computing lower and upper bounds for outlier detection operate_dim = get_operate_dim(array, flag_dim) - if outlier_method == 'quantile': - l_out, u_out = _get_outliers_quantile(array, flag_dim, - **outliers_kwargs) + if outlier_method == "quantile": + l_out, u_out = _get_outliers_quantile(array, flag_dim, **outliers_kwargs) - elif outlier_method == 'trimmed': - l_out, u_out = _get_outliers_trimmed(array, flag_dim, - **outliers_kwargs) + elif outlier_method == "trimmed": + l_out, u_out = _get_outliers_trimmed(array, flag_dim, **outliers_kwargs) - elif outlier_method == 'fixed': + elif outlier_method == "fixed": l_out, u_out = outliers_kwargs["lower"], outliers_kwargs["upper"] else: - raise ValueError("outlier_method must be 'quantile', 'trimmed'" - f", or 'fixed'. Got {outlier_method}") + raise ValueError( + "outlier_method must be 'quantile', 'trimmed'" + f", or 'fixed'. Got {outlier_method}" + ) # Calculating the proportion of outliers along dimension operate_dim # and marking items along dimension flag_dim if this number is # larger than outlier_mask = xr.zeros_like(array, dtype=bool) - if init_dir == 'pos' or init_dir == 'both': # for positive outliers + if init_dir == "pos" or init_dir == "both": # for positive outliers outlier_mask = outlier_mask | (array > u_out) - if init_dir == 'neg' or init_dir == 'both': # for negative outliers + if init_dir == "neg" or init_dir == "both": # for negative outliers outlier_mask = outlier_mask | (array < l_out) # average column of outlier_mask # drop quantile coord because it is no longer needed prop_outliers = outlier_mask.astype(float).mean(operate_dim) if "quantile" in list(prop_outliers.coords.keys()): - prop_outliers = prop_outliers.drop_vars('quantile') + prop_outliers = prop_outliers.drop_vars("quantile") return prop_outliers[prop_outliers > flag_crit].coords.to_index().values @@ -238,24 +243,28 @@ def _threshold_volt_std(epochs, flag_dim, threshold=5e-5): if isinstance(threshold, (tuple, list)): assert len(threshold) == 2 l_out, u_out = threshold - init_dir = 'both' + init_dir = "both" elif isinstance(threshold, float): l_out, u_out = (0, threshold) - init_dir = 'pos' + init_dir = "pos" else: - raise ValueError('threshold must be an int, float, or a list/tuple' - f' of 2 int or float values. got {threshold}') + raise ValueError( + "threshold must be an int, float, or a list/tuple" + f" of 2 int or float values. got {threshold}" + ) epochs_xr = epochs_to_xr(epochs, kind="ch") data_sd = epochs_xr.std("time") # Flag channels or epochs if their std is above # a fixed threshold. outliers_kwargs = dict(lower=l_out, upper=u_out) - volt_outlier_inds = _detect_outliers(data_sd, - flag_dim=flag_dim, - outlier_method='fixed', - init_dir=init_dir, - outliers_kwargs=outliers_kwargs) + volt_outlier_inds = _detect_outliers( + data_sd, + flag_dim=flag_dim, + outlier_method="fixed", + init_dir=init_dir, + outliers_kwargs=outliers_kwargs, + ) return volt_outlier_inds @@ -278,58 +287,68 @@ def chan_neighbour_r(epochs, nneigbr, method): Xarray : Xarray.DataArray An instance of Xarray.DataArray """ - chan_locs = pd.DataFrame(epochs.get_montage().get_positions()['ch_pos']).T - chan_dist = pd.DataFrame(distance_matrix(chan_locs, chan_locs), - columns=chan_locs.index, - index=chan_locs.index) - rank = chan_dist.rank('columns', ascending=True) - 1 + chan_locs = pd.DataFrame(epochs.get_montage().get_positions()["ch_pos"]).T + chan_dist = pd.DataFrame( + distance_matrix(chan_locs, chan_locs), + columns=chan_locs.index, + index=chan_locs.index, + ) + rank = chan_dist.rank("columns", ascending=True) - 1 rank[rank == 0] = np.nan - nearest_neighbor = pd.DataFrame({ch_name: row.dropna() - .sort_values()[:nneigbr] - .index.values - for ch_name, row in rank.iterrows()}).T + nearest_neighbor = pd.DataFrame( + { + ch_name: row.dropna().sort_values()[:nneigbr].index.values + for ch_name, row in rank.iterrows() + } + ).T r_list = [] for name, row in tqdm(list(nearest_neighbor.iterrows())): this_ch = epochs.get_data(name) nearest_chs = epochs.get_data(list(row.values)) - this_ch_xr = xr.DataArray([this_ch * np.ones_like(nearest_chs)], - dims=['ref_chan', 'epoch', - 'channel', 'time'], - coords={'ref_chan': [name], - 'epoch': np.arange(len(epochs)), - 'channel': row.values.tolist(), - 'time': epochs.times - } - ) - nearest_chs_xr = xr.DataArray([nearest_chs], - dims=['ref_chan', 'epoch', - 'channel', 'time'], - coords={'ref_chan': [name], - 'epoch': np.arange(len(epochs)), - 'channel': row.values.tolist(), - 'time': epochs.times}) - r_list.append(xr.corr(this_ch_xr, nearest_chs_xr, dim=['time'])) - - c_neigbr_r = xr.concat(r_list, dim='ref_chan') - - if method == 'max': - m_neigbr_r = xr.apply_ufunc(np.abs, c_neigbr_r).max(dim='channel') - - elif method == 'mean': - m_neigbr_r = xr.apply_ufunc(np.abs, c_neigbr_r).mean(dim='channel') - - elif method == 'trimmean': + this_ch_xr = xr.DataArray( + [this_ch * np.ones_like(nearest_chs)], + dims=["ref_chan", "epoch", "channel", "time"], + coords={ + "ref_chan": [name], + "epoch": np.arange(len(epochs)), + "channel": row.values.tolist(), + "time": epochs.times, + }, + ) + nearest_chs_xr = xr.DataArray( + [nearest_chs], + dims=["ref_chan", "epoch", "channel", "time"], + coords={ + "ref_chan": [name], + "epoch": np.arange(len(epochs)), + "channel": row.values.tolist(), + "time": epochs.times, + }, + ) + r_list.append(xr.corr(this_ch_xr, nearest_chs_xr, dim=["time"])) + + c_neigbr_r = xr.concat(r_list, dim="ref_chan") + + if method == "max": + m_neigbr_r = xr.apply_ufunc(np.abs, c_neigbr_r).max(dim="channel") + + elif method == "mean": + m_neigbr_r = xr.apply_ufunc(np.abs, c_neigbr_r).mean(dim="channel") + + elif method == "trimmean": trim_mean_10 = partial(scipy.stats.trim_mean, proportiontocut=0.1) - m_neigbr_r = xr.apply_ufunc(np.abs, c_neigbr_r)\ - .reduce(trim_mean_10, dim='channel') + m_neigbr_r = xr.apply_ufunc(np.abs, c_neigbr_r).reduce( + trim_mean_10, dim="channel" + ) return m_neigbr_r.rename(ref_chan="ch") # TODO: check that annot type contains all unique flags -def marks_flag_gap(raw, min_gap_ms, included_annot_type=None, - out_annot_name='bad_pylossless_gap'): +def marks_flag_gap( + raw, min_gap_ms, included_annot_type=None, out_annot_name="bad_pylossless_gap" +): """Mark small gaps in time between pylossless annotations. Parameters @@ -353,34 +372,51 @@ def marks_flag_gap(raw, min_gap_ms, included_annot_type=None, An instance of `mne.Annotations` """ if included_annot_type is None: - included_annot_type = ('bad_pylossless_ch_sd', 'bad_pylossless_low_r', - 'bad_pylossless_ic_sd1', 'bad_pylossless_gap') + included_annot_type = ( + "bad_pylossless_ch_sd", + "bad_pylossless_low_r", + "bad_pylossless_ic_sd1", + "bad_pylossless_gap", + ) if len(raw.annotations) == 0: return mne.Annotations([], [], [], orig_time=raw.annotations.orig_time) - ret_val = np.array([[annot['onset'], annot['duration']] - for annot in raw.annotations - if annot['description'] in included_annot_type]).T + ret_val = np.array( + [ + [annot["onset"], annot["duration"]] + for annot in raw.annotations + if annot["description"] in included_annot_type + ] + ).T if len(ret_val) == 0: return mne.Annotations([], [], [], orig_time=raw.annotations.orig_time) onsets, durations = ret_val offsets = onsets + durations - gaps = np.array([min(onset - offsets[offsets < onset]) - if np.sum(offsets < onset) else np.inf - for onset in onsets[1:]]) + gaps = np.array( + [ + min(onset - offsets[offsets < onset]) if np.sum(offsets < onset) else np.inf + for onset in onsets[1:] + ] + ) gap_mask = gaps < min_gap_ms / 1000 - return mne.Annotations(onset=onsets[1:][gap_mask] - gaps[gap_mask], - duration=gaps[gap_mask], - description=out_annot_name, - orig_time=raw.annotations.orig_time) + return mne.Annotations( + onset=onsets[1:][gap_mask] - gaps[gap_mask], + duration=gaps[gap_mask], + description=out_annot_name, + orig_time=raw.annotations.orig_time, + ) -def coregister(raw_edf, fiducials="estimated", # get fiducials from fsaverage - show_coreg=False, verbose=False): +def coregister( + raw_edf, + fiducials="estimated", # get fiducials from fsaverage + show_coreg=False, + verbose=False, +): """Coregister Raw object to `'fsaverage'`. Parameters @@ -400,17 +436,18 @@ def coregister(raw_edf, fiducials="estimated", # get fiducials from fsaverage coregistration | numpy.array a numpy array containing the coregistration trans values. """ - plot_kwargs = dict(subject='fsaverage', - surfaces="head-dense", dig=True, show_axes=True) + plot_kwargs = dict( + subject="fsaverage", surfaces="head-dense", dig=True, show_axes=True + ) - coreg = Coregistration(raw_edf.info, 'fsaverage', fiducials=fiducials) + coreg = Coregistration(raw_edf.info, "fsaverage", fiducials=fiducials) coreg.fit_fiducials(verbose=verbose) - coreg.fit_icp(n_iterations=20, nasion_weight=10., verbose=verbose) + coreg.fit_icp(n_iterations=20, nasion_weight=10.0, verbose=verbose) if show_coreg: mne.viz.plot_alignment(raw_edf.info, trans=coreg.trans, **plot_kwargs) - return coreg.trans['trans'][:-1].ravel() + return coreg.trans["trans"][:-1].ravel() # Warp locations to standard head surface: @@ -426,8 +463,8 @@ def warp_locs(self, raw): ------- None (operates in place) """ - if 'montage_info' in self.config['replace_string']: - if isinstance(self.config['replace_string']['montage_info'], str): + if "montage_info" in self.config["replace_string"]: + if isinstance(self.config["replace_string"]["montage_info"], str): pass # TODO: if it is a BIDS channel tsv, load the tsv,sd_t_f_vals # else read the file that is assumed to be a transformation matrix. @@ -439,7 +476,7 @@ def warp_locs(self, raw): # MNE does not apply the transform to the montage permanently. -class LosslessPipeline(): +class LosslessPipeline: """Class used to handle pipeline parameters.""" def __init__(self, config_fname=None): @@ -451,9 +488,11 @@ def __init__(self, config_fname=None): path to config file specifying the parameters to be used in the pipeline. """ - self.flags = {"ch": FlaggedChs(self), - "epoch": FlaggedEpochs(self), - "ic": FlaggedICs()} + self.flags = { + "ch": FlaggedChs(self), + "epoch": FlaggedEpochs(self), + "ic": FlaggedICs(), + } self.config_fname = config_fname if config_fname: self.load_config() @@ -473,19 +512,21 @@ def _check_sfreq(self): at 0.98 when it should start at 1, which will result in 2 epochs being dropped the next time data are epoched. """ - sfreq = self.raw.info['sfreq'] + sfreq = self.raw.info["sfreq"] if not sfreq.is_integer(): # we can't use f-strings in the logging module - msg = ("The Raw sampling frequency is %.2f. a non-integer" - " sampling frequency can cause incorrect mapping of epochs " - "to annotations. downsampling to %d" % (sfreq, int(sfreq))) + msg = ( + "The Raw sampling frequency is %.2f. a non-integer" + " sampling frequency can cause incorrect mapping of epochs " + "to annotations. downsampling to %d" % (sfreq, int(sfreq)) + ) logger.warn(msg) self.raw.resample(int(sfreq)) return self.raw def set_montage(self): """Set the montage.""" - analysis_montage = self.config['project']['analysis_montage'] + analysis_montage = self.config["project"]["analysis_montage"] if analysis_montage == "" and self.raw.get_montage() is not None: # No analysis montage has been specified and raw already has # a montage. Nothing to do; just return. This can happen @@ -496,14 +537,15 @@ def set_montage(self): if analysis_montage in mne.channels.montage.get_builtin_montages(): # If chanlocs is a string of one the standard MNE montages montage = mne.channels.make_standard_montage(analysis_montage) - montage_kwargs = self.config['project']['set_montage_kwargs'] - self.raw.set_montage(montage, - **montage_kwargs) + montage_kwargs = self.config["project"]["set_montage_kwargs"] + self.raw.set_montage(montage, **montage_kwargs) else: # If the montage is a filepath of a custom montage - raise ValueError('self.config["project"]["analysis_montage"]' - ' should be one of the default MNE montages as' - ' specified by' - ' mne.channels.get_builtin_montages().') + raise ValueError( + 'self.config["project"]["analysis_montage"]' + " should be one of the default MNE montages as" + " specified by" + " mne.channels.get_builtin_montages()." + ) # montage = read_custom_montage(chan_locs) def add_pylossless_annotations(self, inds, event_type, epochs): @@ -519,29 +561,28 @@ def add_pylossless_annotations(self, inds, event_type, epochs): an instance of mne.Epochs """ # Concatenate epoched data back to continuous data - t_onset = epochs.events[inds, 0] / epochs.info['sfreq'] + t_onset = epochs.events[inds, 0] / epochs.info["sfreq"] # We exclude the last sample from the duration because # if the annot lasts the whole duration of the epoch # it's end will coincide with the first sample of the # next epoch, causing it to erroneously be rejected. - duration = (np.ones_like(t_onset) / - epochs.info['sfreq'] * len(epochs.times[:-1]) - ) - description = [f'bad_pylossless_{event_type}'] * len(t_onset) - annotations = mne.Annotations(t_onset, duration, description, - orig_time=self.raw.annotations.orig_time) + duration = np.ones_like(t_onset) / epochs.info["sfreq"] * len(epochs.times[:-1]) + description = [f"bad_pylossless_{event_type}"] * len(t_onset) + annotations = mne.Annotations( + t_onset, duration, description, orig_time=self.raw.annotations.orig_time + ) self.raw.set_annotations(self.raw.annotations + annotations) def get_events(self): """Make an MNE events array of fixed length events.""" - tmin = self.config['epoching']['epochs_args']['tmin'] - tmax = self.config['epoching']['epochs_args']['tmax'] - overlap = self.config['epoching']['overlap'] - return mne.make_fixed_length_events(self.raw, duration=tmax-tmin, - overlap=overlap) - - def get_epochs(self, detrend=None, preload=True, rereference=True, - picks='eeg'): + tmin = self.config["epoching"]["epochs_args"]["tmin"] + tmax = self.config["epoching"]["epochs_args"]["tmax"] + overlap = self.config["epoching"]["overlap"] + return mne.make_fixed_length_events( + self.raw, duration=tmax - tmin, overlap=overlap + ) + + def get_epochs(self, detrend=None, preload=True, rereference=True, picks="eeg"): """Create mne.Epochs according to user arguments. Parameters @@ -566,21 +607,18 @@ def get_epochs(self, detrend=None, preload=True, rereference=True, # TODO: automatically load detrend/preload description from MNE. logger.info("🧹 Epoching..") events = self.get_events() - epoching_kwargs = deepcopy(self.config['epoching']['epochs_args']) + epoching_kwargs = deepcopy(self.config["epoching"]["epochs_args"]) # MNE epoching is end-inclusive, causing an extra time # sample be included. This removes that extra sample: # https://github.com/mne-tools/mne-python/issues/6932 - epoching_kwargs["tmax"] -= 1 / self.raw.info['sfreq'] + epoching_kwargs["tmax"] -= 1 / self.raw.info["sfreq"] if detrend is not None: - epoching_kwargs['detrend'] = detrend - epochs = mne.Epochs(self.raw, events=events, - preload=preload, **epoching_kwargs) - epochs = (epochs.pick(picks=picks, exclude='bads') - .pick(picks=None, - exclude=list(self.flags["ch"].get_flagged()) - ) - ) + epoching_kwargs["detrend"] = detrend + epochs = mne.Epochs(self.raw, events=events, preload=preload, **epoching_kwargs) + epochs = epochs.pick(picks=picks, exclude="bads").pick( + picks=None, exclude=list(self.flags["ch"].get_flagged()) + ) if rereference: self.flags["ch"].rereference(epochs) @@ -589,8 +627,8 @@ def get_epochs(self, detrend=None, preload=True, rereference=True, def run_staging_script(self): """Run a staging script if specified in config.""" # TODO: - if 'staging_script' in self.config: - staging_script = Path(self.config['staging_script']) + if "staging_script" in self.config: + staging_script = Path(self.config["staging_script"]) if staging_script.exists(): exec(staging_script.open().read()) @@ -611,9 +649,9 @@ def find_breaks(self): for example would unpack two keyword arguments from mne.preprocessing.annotate_break """ - if 'find_breaks' not in self.config or not self.config['find_breaks']: + if "find_breaks" not in self.config or not self.config["find_breaks"]: return - breaks = annotate_break(self.raw, **self.config['find_breaks']) + breaks = annotate_break(self.raw, **self.config["find_breaks"]) self.raw.set_annotations(breaks + self.raw.annotations) def _flag_volt_std(self, flag_dim, threshold=5e-5): @@ -630,6 +668,7 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5): channel x epoch indices will considered an outlier. Defaults to 5e-5, or 50 microvolts. Note that here, 'time' refers to the samples in an epoch. + Notes ----- This method takes an array of shape n_channels x n_epochs x n_times @@ -649,10 +688,10 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5): on. You may need to assess a more appropriate value for your own data. """ epochs = self.get_epochs() - above_threshold = _threshold_volt_std(epochs, - flag_dim=flag_dim, - threshold=threshold) - self.flags[flag_dim].add_flag_cat('volt_std', above_threshold, epochs) + above_threshold = _threshold_volt_std( + epochs, flag_dim=flag_dim, threshold=threshold + ) + self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs) def find_outlier_chs(self, inst): """Detect outlier Channels to leave out of rereference.""" @@ -663,13 +702,14 @@ def find_outlier_chs(self, inst): elif isinstance(inst, mne.Raw): epochs = self.get_epochs(rereference=False) else: - raise TypeError('inst must be an MNE Raw or Epochs object,' - f' but got {type(inst)}.') + raise TypeError( + "inst must be an MNE Raw or Epochs object," f" but got {type(inst)}." + ) epochs_xr = epochs_to_xr(epochs, kind="ch") # Determines comically bad channels, # and leaves them out of average rereference - trim_ch_sd = epochs_xr.std('time') + trim_ch_sd = epochs_xr.std("time") # Measure how diff the std of 1 channel is with respect # to other channels (nonparametric z-score) ch_dist = trim_ch_sd - trim_ch_sd.median(dim="ch") @@ -684,7 +724,7 @@ def find_outlier_chs(self, inst): mdn = np.median(mean_ch_dist) deviation = np.diff(np.quantile(mean_ch_dist, [0.3, 0.7])) - return mean_ch_dist.ch[mean_ch_dist > mdn+6*deviation].values.tolist() + return mean_ch_dist.ch[mean_ch_dist > mdn + 6 * deviation].values.tolist() def flag_channels_fixed_threshold(self, threshold=5e-5): """Flag channels based on the stdev value across the time dimension. @@ -711,13 +751,11 @@ def flag_channels_fixed_threshold(self, threshold=5e-5): with. You may need to assess a more appropriate value for your own data. """ - if 'flag_channels_fixed_threshold' not in self.config: + if "flag_channels_fixed_threshold" not in self.config: return - if 'threshold' in self.config['flag_channels_fixed_threshold']: - threshold = (self.config['flag_channels_fixed_threshold'] - ['threshold'] - ) - self._flag_volt_std(flag_dim='ch', threshold=threshold) + if "threshold" in self.config["flag_channels_fixed_threshold"]: + threshold = self.config["flag_channels_fixed_threshold"]["threshold"] + self._flag_volt_std(flag_dim="ch", threshold=threshold) def flag_epochs_fixed_threshold(self, threshold=5e-5): """Flag epochs based on the stdev value across the time dimension. @@ -744,13 +782,11 @@ def flag_epochs_fixed_threshold(self, threshold=5e-5): with. You may need to assess a more appropriate value for your own data. """ - if 'flag_epochs_fixed_threshold' not in self.config: + if "flag_epochs_fixed_threshold" not in self.config: return - if 'threshold' in self.config['flag_epochs_fixed_threshold']: - threshold = (self.config['flag_epochs_fixed_threshold'] - ['threshold'] - ) - self._flag_volt_std(flag_dim='epoch', threshold=threshold) + if "threshold" in self.config["flag_epochs_fixed_threshold"]: + threshold = self.config["flag_epochs_fixed_threshold"]["threshold"] + self._flag_volt_std(flag_dim="epoch", threshold=threshold) @lossless_logger def flag_ch_sd_ch(self): @@ -771,47 +807,43 @@ def flag_ch_sd_ch(self): data_sd = epochs_xr.std("time") # flag channels for ch_sd - bad_ch_names = _detect_outliers(data_sd, flag_dim='ch', - init_dir='pos', - **self.config['ch_ch_sd']) - logger.info(f'📋 LOSSLESS: Noisy channels: {bad_ch_names}') + bad_ch_names = _detect_outliers( + data_sd, flag_dim="ch", init_dir="pos", **self.config["ch_ch_sd"] + ) + logger.info(f"📋 LOSSLESS: Noisy channels: {bad_ch_names}") - self.flags["ch"].add_flag_cat(kind='ch_sd', - bad_ch_names=bad_ch_names) + self.flags["ch"].add_flag_cat(kind="ch_sd", bad_ch_names=bad_ch_names) @lossless_logger def flag_ch_sd_epoch(self): """Flag epochs with outlying standard deviation.""" # TODO: flag "ch_sd" should be renamed "time_sd" - outlier_methods = ('quantile', 'trimmed', 'fixed') + outlier_methods = ("quantile", "trimmed", "fixed") epochs = self.get_epochs() epochs_xr = epochs_to_xr(epochs, kind="ch") data_sd = epochs_xr.std("time") # flag epochs for ch_sd - if 'epoch_ch_sd' in self.config: - config_epoch = self.config['epoch_ch_sd'] - if 'outlier_method' in config_epoch: - if config_epoch['outlier_method'] is None: - del config_epoch['outlier_method'] - elif config_epoch['outlier_method'] not in outlier_methods: + if "epoch_ch_sd" in self.config: + config_epoch = self.config["epoch_ch_sd"] + if "outlier_method" in config_epoch: + if config_epoch["outlier_method"] is None: + del config_epoch["outlier_method"] + elif config_epoch["outlier_method"] not in outlier_methods: raise NotImplementedError - bad_epoch_inds = _detect_outliers(data_sd, - flag_dim='epoch', - init_dir='pos', - **config_epoch) - logger.info(f'📋 LOSSLESS: Noisy epochs: {bad_epoch_inds}') - self.flags["epoch"].add_flag_cat('ch_sd', - bad_epoch_inds, - epochs) + bad_epoch_inds = _detect_outliers( + data_sd, flag_dim="epoch", init_dir="pos", **config_epoch + ) + logger.info(f"📋 LOSSLESS: Noisy epochs: {bad_epoch_inds}") + self.flags["epoch"].add_flag_cat("ch_sd", bad_epoch_inds, epochs) def get_n_nbr(self): """Calculate nearest neighbour correlation for channels.""" # Calculate nearest neighbour correlation on # non-flagged channels and epochs... epochs = self.get_epochs() - n_nbr_ch = self.config['nearest_neighbors']['n_nbr_ch'] - return chan_neighbour_r(epochs, n_nbr_ch, 'max'), epochs + n_nbr_ch = self.config["nearest_neighbors"]["n_nbr_ch"] + return chan_neighbour_r(epochs, n_nbr_ch, "max"), epochs @lossless_logger def flag_ch_low_r(self): @@ -827,12 +859,12 @@ def flag_ch_low_r(self): data_r_ch = self.get_n_nbr()[0] # Create the window criteria vector for flagging low_r chan_info... - bad_ch_names = _detect_outliers(data_r_ch, flag_dim='ch', - init_dir='neg', - **self.config['ch_low_r']) - logger.info(f'📋 LOSSLESS: Uncorrelated channels: {bad_ch_names}') + bad_ch_names = _detect_outliers( + data_r_ch, flag_dim="ch", init_dir="neg", **self.config["ch_low_r"] + ) + logger.info(f"📋 LOSSLESS: Uncorrelated channels: {bad_ch_names}") # Edit the channel flag info structure - self.flags["ch"].add_flag_cat(kind='low_r', bad_ch_names=bad_ch_names) + self.flags["ch"].add_flag_cat(kind="low_r", bad_ch_names=bad_ch_names) return data_r_ch @lossless_logger @@ -847,28 +879,24 @@ def flag_ch_bridge(self, data_r_ch): # Uses the correlation of neighbours # calculated to flag bridged channels. - msr = data_r_ch.median("epoch") / data_r_ch.reduce(scipy.stats.iqr, - dim="epoch") + msr = data_r_ch.median("epoch") / data_r_ch.reduce(scipy.stats.iqr, dim="epoch") - trim = self.config['bridge']['bridge_trim'] + trim = self.config["bridge"]["bridge_trim"] if trim >= 1: trim /= 100 trim /= 2 - trim_mean = partial(scipy.stats.mstats.trimmed_mean, - limits=(trim, trim)) - trim_std = partial(scipy.stats.mstats.trimmed_std, - limits=(trim, trim)) + trim_mean = partial(scipy.stats.mstats.trimmed_mean, limits=(trim, trim)) + trim_std = partial(scipy.stats.mstats.trimmed_std, limits=(trim, trim)) - z_val = self.config['bridge']['bridge_z'] - mask = (msr > msr.reduce(trim_mean, dim="ch") - + z_val*msr.reduce(trim_std, dim="ch") - ) + z_val = self.config["bridge"]["bridge_z"] + mask = msr > msr.reduce(trim_mean, dim="ch") + z_val * msr.reduce( + trim_std, dim="ch" + ) bad_ch_names = data_r_ch.ch.values[mask] - logger.info(f'📋 LOSSLESS: Bridged channels: {bad_ch_names}') - self.flags["ch"].add_flag_cat(kind='bridge', - bad_ch_names=bad_ch_names) + logger.info(f"📋 LOSSLESS: Bridged channels: {bad_ch_names}") + self.flags["ch"].add_flag_cat(kind="bridge", bad_ch_names=bad_ch_names) @lossless_logger def flag_ch_rank(self, data_r_ch): @@ -883,17 +911,16 @@ def flag_ch_rank(self, data_r_ch): an instance of `numpy.array`. """ if len(self.flags["ch"].get_flagged()): - ch_sel = [ch for ch in data_r_ch.ch.values - if ch not in self.flags["ch"].get_flagged()] + ch_sel = [ + ch + for ch in data_r_ch.ch.values + if ch not in self.flags["ch"].get_flagged() + ] data_r_ch = data_r_ch.sel(ch=ch_sel) - bad_ch_names = [str(data_r_ch.median("epoch") - .idxmax(dim="ch") - .to_numpy() - )] - logger.info(f'📋 LOSSLESS: Rank channel: {bad_ch_names}') - self.flags["ch"].add_flag_cat(kind='rank', - bad_ch_names=bad_ch_names) + bad_ch_names = [str(data_r_ch.median("epoch").idxmax(dim="ch").to_numpy())] + logger.info(f"📋 LOSSLESS: Rank channel: {bad_ch_names}") + self.flags["ch"].add_flag_cat(kind="rank", bad_ch_names=bad_ch_names) @lossless_logger def flag_epoch_low_r(self): @@ -909,18 +936,15 @@ def flag_epoch_low_r(self): # non-flagged channels and epochs... data_r_ch, epochs = self.get_n_nbr() - bad_epoch_inds = _detect_outliers(data_r_ch, flag_dim='epoch', - init_dir='neg', - **self.config['epoch_low_r']) - logger.info(f'📋 LOSSLESS: Uncorrelated epochs: {bad_epoch_inds}') - self.flags["epoch"].add_flag_cat('low_r', - bad_epoch_inds, - epochs) + bad_epoch_inds = _detect_outliers( + data_r_ch, flag_dim="epoch", init_dir="neg", **self.config["epoch_low_r"] + ) + logger.info(f"📋 LOSSLESS: Uncorrelated epochs: {bad_epoch_inds}") + self.flags["epoch"].add_flag_cat("low_r", bad_epoch_inds, epochs) def flag_epoch_gap(self): """Flag small time periods between pylossless annotations.""" - annots = marks_flag_gap(self.raw, - self.config['epoch_gap']['min_gap_ms']) + annots = marks_flag_gap(self.raw, self.config["epoch_gap"]["min_gap_ms"]) self.raw.set_annotations(self.raw.annotations + annots) @lossless_logger @@ -934,18 +958,18 @@ def run_ica(self, run): epochs, 'run2' is the final ICA used to classify components with `mne_icalabel`. """ - ica_kwargs = self.config['ica']['ica_args'][run] - if 'max_iter' not in ica_kwargs: - ica_kwargs['max_iter'] = 'auto' - if 'random_state' not in ica_kwargs: - ica_kwargs['random_state'] = 97 + ica_kwargs = self.config["ica"]["ica_args"][run] + if "max_iter" not in ica_kwargs: + ica_kwargs["max_iter"] = "auto" + if "random_state" not in ica_kwargs: + ica_kwargs["random_state"] = 97 epochs = self.get_epochs() - if run == 'run1': + if run == "run1": self.ica1 = ICA(**ica_kwargs) self.ica1.fit(epochs) - elif run == 'run2': + elif run == "run2": self.ica2 = ICA(**ica_kwargs) self.ica2.fit(epochs) self.flags["ic"].label_components(epochs, self.ica2) @@ -961,15 +985,13 @@ def flag_epoch_ic_sd1(self): # Calculate IC sd by window epochs = self.get_epochs() epochs_xr = epochs_to_xr(epochs, kind="ic", ica=self.ica1) - data_sd = epochs_xr.std('time') + data_sd = epochs_xr.std("time") # Create the windowing sd criteria - kwargs = self.config['ica']['ic_ic_sd'] - bad_epoch_inds = _detect_outliers(data_sd, - flag_dim='epoch', **kwargs) + kwargs = self.config["ica"]["ic_ic_sd"] + bad_epoch_inds = _detect_outliers(data_sd, flag_dim="epoch", **kwargs) - self.flags["epoch"].add_flag_cat('ic_sd1', bad_epoch_inds, - epochs) + self.flags["epoch"].add_flag_cat("ic_sd1", bad_epoch_inds, epochs) # icsd_epoch_flags=padflags(raw, icsd_epoch_flags,1,'value',.5); @@ -983,65 +1005,68 @@ def save(self, derivatives_path, overwrite=False): overwrite : bool (default False) whether to overwrite existing files with the same name. """ - mne_bids.write_raw_bids(self.raw, - derivatives_path, - overwrite=overwrite, - format='EDF', - allow_preload=True) + mne_bids.write_raw_bids( + self.raw, + derivatives_path, + overwrite=overwrite, + format="EDF", + allow_preload=True, + ) # TODO: address derivatives support in MNE bids. # use shutils ( or pathlib?) to rename file with ll suffix # Save ICAs bpath = derivatives_path.copy() - for this_ica, self_ica, in zip(['ica1', 'ica2'], - [self.ica1, self.ica2]): - suffix = this_ica + '_ica' - ica_bidspath = bpath.update(extension='.fif', - suffix=suffix, - check=False) + for ( + this_ica, + self_ica, + ) in zip(["ica1", "ica2"], [self.ica1, self.ica2]): + suffix = this_ica + "_ica" + ica_bidspath = bpath.update(extension=".fif", suffix=suffix, check=False) self_ica.save(ica_bidspath, overwrite=overwrite) # Save IC labels - iclabels_bidspath = bpath.update(extension='.tsv', - suffix='iclabels', - check=False) + iclabels_bidspath = bpath.update( + extension=".tsv", suffix="iclabels", check=False + ) self.flags["ic"].save_tsv(iclabels_bidspath) # TODO: epoch marks and ica marks are not currently saved into annots # raw.save(derivatives_path, overwrite=True, split_naming='bids') - config_bidspath = bpath.update(extension='.yaml', - suffix='ll_config', - check=False) + config_bidspath = bpath.update( + extension=".yaml", suffix="ll_config", check=False + ) self.config.save(config_bidspath) # Save flag["ch"] - flagged_chs_fpath = bpath.update(extension='.tsv', - suffix='ll_FlaggedChs', - check=False) + flagged_chs_fpath = bpath.update( + extension=".tsv", suffix="ll_FlaggedChs", check=False + ) self.flags["ch"].save_tsv(flagged_chs_fpath.fpath) @lossless_logger def filter(self): """Run filter procedure based on structured config args.""" # 5.a. Filter lowpass/highpass - self.raw.filter(**self.config['filtering']['filter_args']) + self.raw.filter(**self.config["filtering"]["filter_args"]) - if 'notch_filter_args' in self.config['filtering']: - notch_args = self.config['filtering']['notch_filter_args'] + if "notch_filter_args" in self.config["filtering"]: + notch_args = self.config["filtering"]["notch_filter_args"] # in raw.notch_filter, freqs=None is ok if method=spectrum_fit - if not notch_args['freqs'] and 'method' not in notch_args: - logger.info('No notch filter arguments provided. Skipping') + if not notch_args["freqs"] and "method" not in notch_args: + logger.info("No notch filter arguments provided. Skipping") else: self.raw.notch_filter(**notch_args) # 5.b. Filter notch - notch_args = self.config['filtering']['notch_filter_args'] - spectrum_fit_method = ('method' in notch_args and - notch_args['method'] == 'spectrum_fit') - if notch_args['freqs'] or spectrum_fit_method: + notch_args = self.config["filtering"]["notch_filter_args"] + spectrum_fit_method = ( + "method" in notch_args and notch_args["method"] == "spectrum_fit" + ) + if notch_args["freqs"] or spectrum_fit_method: # in raw.notch_filter, freqs=None is ok if method=='spectrum_fit' self.raw.notch_filter(**notch_args) else: - logger.info('No notch filter arguments provided. Skipping') + logger.info("No notch filter arguments provided. Skipping") def run(self, bids_path, save=True, overwrite=False): """Run the pylossless pipeline. @@ -1074,7 +1099,6 @@ def run_with_raw(self, raw): @lossless_time def _run(self): - # Make sure sampling frequency is an integer self._check_sfreq() @@ -1097,19 +1121,16 @@ def _run(self): self.flag_ch_sd_epoch(message="Flagging Noisy Time periods") # 5. Filtering - self.filter(message='Filtering') + self.filter(message="Filtering") # 6. calculate nearest neighbort r values - data_r_ch = self.flag_ch_low_r(message="Flagging uncorrelated" - " channels") + data_r_ch = self.flag_ch_low_r(message="Flagging uncorrelated" " channels") # 7. Identify bridged channels - self.flag_ch_bridge(data_r_ch, - message="Flagging Bridged channels") + self.flag_ch_bridge(data_r_ch, message="Flagging Bridged channels") # 8. Flag rank channels - self.flag_ch_rank(data_r_ch, - message="Flagging the rank channel") + self.flag_ch_rank(data_r_ch, message="Flagging the rank channel") # 9. Calculate nearest neighbour R values for epochs self.flag_epoch_low_r(message="Flagging Uncorrelated epochs") @@ -1118,14 +1139,13 @@ def _run(self): self.flag_epoch_gap() # 11. Run ICA - self.run_ica('run1', message="Running Initial ICA") + self.run_ica("run1", message="Running Initial ICA") # 12. Calculate IC SD - self.flag_epoch_ic_sd1(message="Flagging time periods with noisy" - " IC's.") + self.flag_epoch_ic_sd1(message="Flagging time periods with noisy" " IC's.") # 13. TODO: integrate labels from IClabels to self.flags["ic"] - self.run_ica('run2', message="Running Final ICA.") + self.run_ica("run2", message="Running Final ICA.") # 14. Flag very small time periods between flagged time self.flag_epoch_gap() @@ -1150,26 +1170,26 @@ def load_ll_derivative(self, derivatives_path): self.raw = mne_bids.read_raw_bids(derivatives_path) bpath = derivatives_path.copy() # Load ICAs - for this_ica in ['ica1', 'ica2']: - suffix = this_ica + '_ica' - ica_bidspath = bpath.update(extension='.fif', suffix=suffix, - check=False) - setattr(self, this_ica, - mne.preprocessing.read_ica(ica_bidspath.fpath)) + for this_ica in ["ica1", "ica2"]: + suffix = this_ica + "_ica" + ica_bidspath = bpath.update(extension=".fif", suffix=suffix, check=False) + setattr(self, this_ica, mne.preprocessing.read_ica(ica_bidspath.fpath)) # Load IC labels - iclabels_bidspath = bpath.update(extension='.tsv', suffix='iclabels', - check=False) + iclabels_bidspath = bpath.update( + extension=".tsv", suffix="iclabels", check=False + ) self.flags["ic"].load_tsv(iclabels_bidspath.fpath) - self.config_fname = bpath.update(extension='.yaml', suffix='ll_config', - check=False) + self.config_fname = bpath.update( + extension=".yaml", suffix="ll_config", check=False + ) self.load_config() # Load Flagged Chs - flagged_chs_fpath = bpath.update(extension='.tsv', - suffix='ll_FlaggedChs', - check=False) + flagged_chs_fpath = bpath.update( + extension=".tsv", suffix="ll_FlaggedChs", check=False + ) self.flags["ch"].load_tsv(flagged_chs_fpath.fpath) # Load Flagged Epochs @@ -1178,11 +1198,11 @@ def load_ll_derivative(self, derivatives_path): return self # TODO: Finish docstring - def get_derivative_path(self, bids_path, derivative_name='pylossless'): + def get_derivative_path(self, bids_path, derivative_name="pylossless"): """Build derivative path for file.""" lossless_suffix = bids_path.suffix if bids_path.suffix else "" - lossless_suffix += '_ll' - lossless_root = bids_path.root / 'derivatives' / derivative_name - return bids_path.copy().update(suffix=lossless_suffix, - root=lossless_root, - check=False) + lossless_suffix += "_ll" + lossless_root = bids_path.root / "derivatives" / derivative_name + return bids_path.copy().update( + suffix=lossless_suffix, root=lossless_root, check=False + ) diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index 2ac785f..90362cc 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -12,87 +12,93 @@ def load_openneuro_bids(): + """Load a BIDS dataset from OpenNeuro.""" config = ll.config.Config() config.load_default() - config['project']['bids_montage'] = '' - config['project']['analysis_montage'] = 'standard_1020' - config['project']['set_montage_kwargs']['on_missing'] = 'warn' + config["project"]["bids_montage"] = "" + config["project"]["analysis_montage"] = "standard_1020" + config["project"]["set_montage_kwargs"]["on_missing"] = "warn" # Shamelessly copied from # https://mne.tools/mne-bids/stable/auto_examples/read_bids_datasets.html # pip install openneuro-py - dataset = 'ds002778' - subject = 'pd6' + dataset = "ds002778" + subject = "pd6" # Download one subject's data from each dataset - bids_root = Path('.') / dataset + bids_root = Path(".") / dataset # TODO: Delete this directory after test otherwise MNE will think the # sample directory is outdated, and will re-download it the next time # data_path() is called, which is annoying for users. bids_root.mkdir(exist_ok=True) - openneuro.download(dataset=dataset, target_dir=bids_root, - include=[f'sub-{subject}']) - - datatype = 'eeg' - session = 'off' - task = 'rest' - suffix = 'eeg' - bids_path = mne_bids.BIDSPath(subject=subject, session=session, task=task, - suffix=suffix, datatype=datatype, - root=bids_root) - - while not bids_path.fpath.with_suffix('.bdf').exists(): - print(list(bids_path.fpath.glob('*'))) + openneuro.download( + dataset=dataset, target_dir=bids_root, include=[f"sub-{subject}"] + ) + + datatype = "eeg" + session = "off" + task = "rest" + suffix = "eeg" + bids_path = mne_bids.BIDSPath( + subject=subject, + session=session, + task=task, + suffix=suffix, + datatype=datatype, + root=bids_root, + ) + + while not bids_path.fpath.with_suffix(".bdf").exists(): + print(list(bids_path.fpath.glob("*"))) sleep(1) raw = mne_bids.read_raw_bids(bids_path) - annots = mne.Annotations(onset=[1, 15], - duration=[1, 1], - description=['test_annot', 'test_annot']) + annots = mne.Annotations( + onset=[1, 15], duration=[1, 1], description=["test_annot", "test_annot"] + ) raw.set_annotations(annots) return raw, config, bids_root # @pytest.mark.xfail -@pytest.mark.parametrize('dataset, find_breaks', [('openneuro', True), - ('openneuro', False)]) +@pytest.mark.parametrize( + "dataset, find_breaks", [("openneuro", True), ("openneuro", False)] +) def test_pipeline_run(dataset, find_breaks): - """test running the pipeline.""" - if dataset == 'openneuro': + """Test running the pipeline.""" + if dataset == "openneuro": raw, config, bids_root = load_openneuro_bids() if find_breaks: - config['find_breaks'] = {} - config['find_breaks']['min_break_duration'] = 9 - config['find_breaks']['t_start_after_previous'] = 1 - config['find_breaks']['t_stop_before_next'] = 0 + config["find_breaks"] = {} + config["find_breaks"]["min_break_duration"] = 9 + config["find_breaks"]["t_start_after_previous"] = 1 + config["find_breaks"]["t_stop_before_next"] = 0 config.save("test_config.yaml") - pipeline = ll.LosslessPipeline('test_config.yaml') - not_in_1020 = ['EXG1', 'EXG2', 'EXG3', 'EXG4', - 'EXG5', 'EXG6', 'EXG7', 'EXG8'] - pipeline.raw = raw.pick('eeg', - exclude=not_in_1020).load_data() + pipeline = ll.LosslessPipeline("test_config.yaml") + not_in_1020 = ["EXG1", "EXG2", "EXG3", "EXG4", "EXG5", "EXG6", "EXG7", "EXG8"] + pipeline.raw = raw.pick("eeg", exclude=not_in_1020).load_data() pipeline.run_with_raw(pipeline.raw) if find_breaks: - assert 'BAD_break' in raw.annotations.description + assert "BAD_break" in raw.annotations.description - Path('test_config.yaml').unlink() # delete config file + Path("test_config.yaml").unlink() # delete config file shutil.rmtree(bids_root) -@pytest.mark.parametrize('logging', [True, False]) +@pytest.mark.parametrize("logging", [True, False]) def test_find_breaks(logging): """Make sure MNE's annotate_break function can run.""" testing_path = mne.datasets.testing.data_path() - fname = testing_path / 'EDF' / 'test_edf_overlapping_annotations.edf' + fname = testing_path / "EDF" / "test_edf_overlapping_annotations.edf" raw = mne.io.read_raw_edf(fname, preload=True) config_fname = "find_breaks_config.yaml" config = ll.config.Config() config.load_default() - config['find_breaks'] = {} - config['find_breaks']['min_break_duration'] = 15 + config["find_breaks"] = {} + config["find_breaks"]["min_break_duration"] = 15 config.save(config_fname) pipeline = ll.LosslessPipeline(config_fname) pipeline.raw = raw diff --git a/pylossless/tests/test_simulated.py b/pylossless/tests/test_simulated.py index a0833ad..2cf84b9 100644 --- a/pylossless/tests/test_simulated.py +++ b/pylossless/tests/test_simulated.py @@ -16,13 +16,13 @@ # LOAD DATA data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" # make BIDS object bpath = mne_bids.get_bids_path_from_fname(raw_fname, check=False) -bpath.suffix = 'sample_audvis_raw' +bpath.suffix = "sample_audvis_raw" # Load real data as the template @@ -32,41 +32,45 @@ # GENERATE DIPOLE TIME SERIES n_dipoles = 4 # number of dipoles to create -epoch_duration = 2. # duration of each epoch/event +epoch_duration = 2.0 # duration of each epoch/event n = 0 # harmonic number rng = np.random.RandomState(0) # random state (make reproducible) np.random.seed(5) def data_fun(times): - """Generate time-staggered sinusoids at harmonics of 10Hz""" + """Generate time-staggered sinusoids at harmonics of 10Hz.""" global n n_samp = len(times) window = np.zeros(n_samp) - start, stop = [int(ii * float(n_samp) / (2 * n_dipoles)) - for ii in (2 * n, 2 * n + 1)] - window[start:stop] = 1. + start, stop = [ + int(ii * float(n_samp) / (2 * n_dipoles)) for ii in (2 * n, 2 * n + 1) + ] + window[start:stop] = 1.0 n += 1 - data = 25e-9 * np.sin(2. * np.pi * 10. * n * times) + data = 25e-9 * np.sin(2.0 * np.pi * 10.0 * n * times) data *= window return data -times = raw.times[:int(raw.info['sfreq'] * epoch_duration)] +times = raw.times[: int(raw.info["sfreq"] * epoch_duration)] fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] -stc = simulate_sparse_stc(src, n_dipoles=n_dipoles, times=times, - data_fun=data_fun, random_state=rng) +src = fwd["src"] +stc = simulate_sparse_stc( + src, n_dipoles=n_dipoles, times=times, data_fun=data_fun, random_state=rng +) # SIMULATE RAW DATA raw_sim = simulate_raw(raw.info, [stc] * 10, forward=fwd, verbose=True) -raw_sim.pick('eeg') +raw_sim.pick("eeg") # Save Info and Montage for later re-use montage = raw_sim.get_montage() -info = mne.create_info(ch_names=raw_sim.ch_names, - sfreq=raw_sim.info['sfreq'], - ch_types=raw_sim.get_channel_types()) +info = mne.create_info( + ch_names=raw_sim.ch_names, + sfreq=raw_sim.info["sfreq"], + ch_types=raw_sim.get_channel_types(), +) # MAKE A VERY NOISY TIME PERIOD @@ -74,11 +78,10 @@ def data_fun(times): raw_selection2 = raw_sim.copy().crop(tmin=2, tmax=3, include_tmax=False) raw_selection3 = raw_sim.copy().crop(tmin=3, tmax=19.994505956825666) -cov_noisy_period = make_ad_hoc_cov(raw_selection2.info, std=dict(eeg=.000002)) -add_noise(raw_selection2, - cov_noisy_period, - iir_filter=[0.2, -0.2, 0.04], - random_state=rng) +cov_noisy_period = make_ad_hoc_cov(raw_selection2.info, std=dict(eeg=0.000002)) +add_noise( + raw_selection2, cov_noisy_period, iir_filter=[0.2, -0.2, 0.04], random_state=rng +) raw_selection1.append([raw_selection2, raw_selection3]) raw_selection1.set_annotations(None) @@ -88,25 +91,24 @@ def data_fun(times): cov = make_ad_hoc_cov(raw_sim.info) add_noise(raw_sim, cov, iir_filter=[0.2, -0.2, 0.04], random_state=rng) -make_these_noisy = ['EEG 001', 'EEG 003'] -cov_noisy = make_ad_hoc_cov(raw_sim.copy().pick(make_these_noisy).info, - std=dict(eeg=.000002)) +make_these_noisy = ["EEG 001", "EEG 003"] +cov_noisy = make_ad_hoc_cov( + raw_sim.copy().pick(make_these_noisy).info, std=dict(eeg=0.000002) +) add_noise(raw_sim, cov_noisy, iir_filter=[0.2, -0.2, 0.04], random_state=rng) # MAKE LESS NOISY CHANNELS -make_these_noisy = ['EEG 005', 'EEG 007'] +make_these_noisy = ["EEG 005", "EEG 007"] raw_selection1 = raw_sim.copy().crop(tmin=0, tmax=8, include_tmax=False) raw_selection2 = raw_sim.copy().crop(tmin=8, tmax=19.994505956825666) -cov_less_noisy = make_ad_hoc_cov((raw_selection1.copy() - .pick(make_these_noisy) - .info), - std=dict(eeg=.0000008)) -add_noise(raw_selection1, - cov_less_noisy, - iir_filter=[0.2, -0.2, 0.04], - random_state=rng) +cov_less_noisy = make_ad_hoc_cov( + (raw_selection1.copy().pick(make_these_noisy).info), std=dict(eeg=0.0000008) +) +add_noise( + raw_selection1, cov_less_noisy, iir_filter=[0.2, -0.2, 0.04], random_state=rng +) raw_selection1.append([raw_selection2]) raw_selection1.set_annotations(None) raw_sim = raw_selection1 @@ -117,10 +119,8 @@ def data_fun(times): # Make the last channel random. save for later use min_val = data[23, :].min() -max_val = data[23, :].min() + .0000065 -low_correlated_ch = np.random.uniform(low=min_val, - high=max_val, - size=len(data[23, :])) +max_val = data[23, :].min() + 0.0000065 +low_correlated_ch = np.random.uniform(low=min_val, high=max_val, size=len(data[23, :])) # MAKE AN UNCORRELATED CH data[23] = low_correlated_ch @@ -145,67 +145,68 @@ def data_fun(times): config.load_default() # CUSTOMIZE CONFIG -config['ch_ch_sd']['outliers_kwargs']['k'] = 3 -config['ch_ch_sd']['outliers_kwargs']['lower'] = .15 -config['ch_ch_sd']['outliers_kwargs']['upper'] = .85 +config["ch_ch_sd"]["outliers_kwargs"]["k"] = 3 +config["ch_ch_sd"]["outliers_kwargs"]["lower"] = 0.15 +config["ch_ch_sd"]["outliers_kwargs"]["upper"] = 0.85 -config['epoch_ch_sd']['outliers_kwargs']['k'] = 3 -config['epoch_ch_sd']['outliers_kwargs']['lower'] = .15 -config['epoch_ch_sd']['outliers_kwargs']['upper'] = .85 +config["epoch_ch_sd"]["outliers_kwargs"]["k"] = 3 +config["epoch_ch_sd"]["outliers_kwargs"]["lower"] = 0.15 +config["epoch_ch_sd"]["outliers_kwargs"]["upper"] = 0.85 -config['ch_low_r']['outliers_kwargs']['k'] = 2 -config['ch_low_r']['outliers_kwargs']['lower'] = .23 -config['ch_low_r']['outliers_kwargs']['upper'] = .85 -config['ch_low_r']['flag_crit'] = .25 +config["ch_low_r"]["outliers_kwargs"]["k"] = 2 +config["ch_low_r"]["outliers_kwargs"]["lower"] = 0.23 +config["ch_low_r"]["outliers_kwargs"]["upper"] = 0.85 +config["ch_low_r"]["flag_crit"] = 0.25 -config['epoch_low_r']['outliers_kwargs']['k'] = 3 -config['epoch_low_r']['outliers_kwargs']['lower'] = .15 -config['epoch_low_r']['outliers_kwargs']['upper'] = .85 +config["epoch_low_r"]["outliers_kwargs"]["k"] = 3 +config["epoch_low_r"]["outliers_kwargs"]["lower"] = 0.15 +config["epoch_low_r"]["outliers_kwargs"]["upper"] = 0.85 config.save("project_ll_config_face13_egi.yaml") -pipeline = ll.LosslessPipeline('project_ll_config_face13_egi.yaml') +pipeline = ll.LosslessPipeline("project_ll_config_face13_egi.yaml") config.save("sample_audvis_config.yaml") # GENERATE PIPELINE -pipeline = ll.LosslessPipeline('sample_audvis_config.yaml') +pipeline = ll.LosslessPipeline("sample_audvis_config.yaml") pipeline.raw = raw_sim # TEST -@pytest.mark.parametrize('pipeline', - [(pipeline)]) +@pytest.mark.parametrize("pipeline", [(pipeline)]) def test_simulated_raw(pipeline): + """Test pipeline on simulated EEG.""" pipeline._check_sfreq() # This file should have been downsampled - assert pipeline.raw.info['sfreq'] == 600 + assert pipeline.raw.info["sfreq"] == 600 # FIND NOISY EPOCHS pipeline.flag_ch_sd_epoch() # Epoch 2 was made noisy and should be flagged. - assert np.array_equal(pipeline.flags['epoch']['ch_sd'], [2]) + assert np.array_equal(pipeline.flags["epoch"]["ch_sd"], [2]) epochs = pipeline.get_epochs() # only epoch at indice 2 should have been dropped assert all(not tup or i == 2 for i, tup in enumerate(epochs.drop_log)) # RUN FLAG_CH_SD pipeline.flag_ch_sd_ch() - noisy_chs = ['EEG 001', 'EEG 003', 'EEG 005', 'EEG 007'] - assert np.array_equal(pipeline.flags['ch']['ch_sd'], noisy_chs) + noisy_chs = ["EEG 001", "EEG 003", "EEG 005", "EEG 007"] + assert np.array_equal(pipeline.flags["ch"]["ch_sd"], noisy_chs) # FIND UNCORRELATED CHS data_r_ch = pipeline.flag_ch_low_r() # Previously flagged chs should not be in the correlation array - assert all([name not in data_r_ch.coords['ch'] - for name in pipeline.flags['ch']['ch_sd']]) + assert all( + [name not in data_r_ch.coords["ch"] for name in pipeline.flags["ch"]["ch_sd"]] + ) # EEG 024 was made random and should be flagged. - assert ['EEG 024'] in pipeline.flags['ch']['low_r'] + assert ["EEG 024"] in pipeline.flags["ch"]["low_r"] # RUN FLAG_CH_BRIDGE data_r_ch = pipeline.flag_ch_low_r() pipeline.flag_ch_bridge(data_r_ch) # Channels below are duplicates and should be flagged. - assert 'EEG 053' in pipeline.flags['ch']['bridge'] - assert 'EEG 054' in pipeline.flags['ch']['bridge'] + assert "EEG 053" in pipeline.flags["ch"]["bridge"] + assert "EEG 054" in pipeline.flags["ch"]["bridge"] # Delete temp config file tmp_config_fname = Path(pipeline.config_fname).absolute() diff --git a/setup.py b/setup.py index d5df8bf..a8fb55d 100644 --- a/setup.py +++ b/setup.py @@ -10,13 +10,13 @@ from pathlib import Path from setuptools import setup, find_packages -with Path('requirements.txt').open() as f: +with Path("requirements.txt").open() as f: requirements = f.read().splitlines() extras = { - 'dash': 'requirements_qc.txt', - 'test': 'requirements_testing.txt', - 'doc': './docs/requirements_doc.txt' + "dash": "requirements_qc.txt", + "test": "requirements_testing.txt", + "doc": "./docs/requirements_doc.txt", } extras_require = {} @@ -30,17 +30,17 @@ qc_entry_point = ["pylossless_qc=pylossless.dash.pylossless_qc:main"] setup( - name='pylossless', - version='0.1.0', - description='Lossless EEG Processing Pipeline Built on MNE and Dash', + name="pylossless", + version="0.1.0", + description="Lossless EEG Processing Pipeline Built on MNE and Dash", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", author="Scott Huberty", - author_email='seh33@uw.edu', - url='https://github.com/lina-usc/pylossless', + author_email="seh33@uw.edu", + url="https://github.com/lina-usc/pylossless", packages=find_packages(), install_requires=requirements, extras_require=extras_require, include_package_data=True, - entry_points={"console_scripts": qc_entry_point} + entry_points={"console_scripts": qc_entry_point}, ) From 256c38f5ecbe39390326e3764d407eed0e615e67 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 1 Sep 2023 12:37:12 -0400 Subject: [PATCH 04/17] TST: Error on warnings in CI, fix warnings in code - use mne.util. warn to accurately raise warnings --- pylossless/pipeline.py | 10 +++++----- pylossless/tests/test_simulated.py | 8 +------- pyproject.toml | 7 ++++++- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py index dc989d2..78bc3ac 100644 --- a/pylossless/pipeline.py +++ b/pylossless/pipeline.py @@ -24,7 +24,7 @@ from mne.preprocessing import annotate_break from mne.preprocessing import ICA from mne.coreg import Coregistration -from mne.utils import logger +from mne.utils import logger, warn import mne_bids from mne_bids import get_bids_path_from_fname, BIDSPath @@ -516,11 +516,11 @@ def _check_sfreq(self): if not sfreq.is_integer(): # we can't use f-strings in the logging module msg = ( - "The Raw sampling frequency is %.2f. a non-integer" - " sampling frequency can cause incorrect mapping of epochs " - "to annotations. downsampling to %d" % (sfreq, int(sfreq)) + f"The Raw sampling frequency is {sfreq:.2f}. a non-integer " + f"sampling frequency can cause incorrect mapping of epochs " + f"to annotations. downsampling to {int(sfreq)}" ) - logger.warn(msg) + warn(msg) self.raw.resample(int(sfreq)) return self.raw diff --git a/pylossless/tests/test_simulated.py b/pylossless/tests/test_simulated.py index 2cf84b9..9be05f4 100644 --- a/pylossless/tests/test_simulated.py +++ b/pylossless/tests/test_simulated.py @@ -10,8 +10,6 @@ from mne.simulation import simulate_sparse_stc, simulate_raw, add_noise -import mne_bids - import pylossless as ll # LOAD DATA @@ -20,13 +18,9 @@ raw_fname = meg_path / "sample_audvis_raw.fif" fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" -# make BIDS object -bpath = mne_bids.get_bids_path_from_fname(raw_fname, check=False) -bpath.suffix = "sample_audvis_raw" - # Load real data as the template -raw = mne_bids.read_raw_bids(bpath) +raw = mne.io.read_raw_fif(raw_fname) raw.set_eeg_reference(projection=True) diff --git a/pyproject.toml b/pyproject.toml index 501abeb..f0e14d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,4 +18,9 @@ ignore-decorators = [ ] [tool.black] -exclude = "(dist/)|(build/)|(.*\\.ipynb)" # Exclude build artifacts and notebooks \ No newline at end of file +exclude = "(dist/)|(build/)|(.*\\.ipynb)" # Exclude build artifacts and notebooks + +[tool.pytest.ini_options] +filterwarnings = ["error", +# error on warning except the non-int sample freq warning, which we want to be raised +'ignore:The Raw sampling frequency is',] \ No newline at end of file From faa694808677ac123e99bfd57e9ad3f3978bdd39 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 1 Sep 2023 14:49:29 -0400 Subject: [PATCH 05/17] TST: update test_simulated cleaned up, and dropped low correlation assertion for now. commented the github issue number for reference --- pylossless/tests/test_simulated.py | 91 ++++++++++-------------------- 1 file changed, 29 insertions(+), 62 deletions(-) diff --git a/pylossless/tests/test_simulated.py b/pylossless/tests/test_simulated.py index 9be05f4..f955e99 100644 --- a/pylossless/tests/test_simulated.py +++ b/pylossless/tests/test_simulated.py @@ -3,12 +3,19 @@ from pathlib import Path import numpy as np +from numpy.testing import assert_array_equal import mne from mne import make_ad_hoc_cov from mne.datasets import sample -from mne.simulation import simulate_sparse_stc, simulate_raw, add_noise +from mne.simulation import ( + simulate_sparse_stc, + simulate_raw, + add_noise, + add_ecg, + add_eog, +) import pylossless as ll @@ -56,6 +63,10 @@ def data_fun(times): # SIMULATE RAW DATA raw_sim = simulate_raw(raw.info, [stc] * 10, forward=fwd, verbose=True) +cov = make_ad_hoc_cov(raw_sim.info) +add_noise(raw_sim, cov, iir_filter=[0.2, -0.2, 0.04], random_state=rng) +add_ecg(raw_sim, random_state=rng) +add_eog(raw_sim, random_state=rng) raw_sim.pick("eeg") # Save Info and Montage for later re-use @@ -82,44 +93,27 @@ def data_fun(times): raw_sim = raw_selection1 # MAKE SOME VERY NOISY CHANNELS -cov = make_ad_hoc_cov(raw_sim.info) -add_noise(raw_sim, cov, iir_filter=[0.2, -0.2, 0.04], random_state=rng) - -make_these_noisy = ["EEG 001", "EEG 003"] -cov_noisy = make_ad_hoc_cov( - raw_sim.copy().pick(make_these_noisy).info, std=dict(eeg=0.000002) -) -add_noise(raw_sim, cov_noisy, iir_filter=[0.2, -0.2, 0.04], random_state=rng) -# MAKE LESS NOISY CHANNELS -make_these_noisy = ["EEG 005", "EEG 007"] +make_these_noisy = ["EEG 001", "EEG 002"] +raw_noisy_chs = raw_sim.copy().pick(make_these_noisy) +cov_noisy = make_ad_hoc_cov(raw_noisy_chs.info, std=dict(eeg=0.000002)) +add_noise(raw_noisy_chs, cov_noisy, iir_filter=[0.2, -0.2, 0.04], random_state=rng) -raw_selection1 = raw_sim.copy().crop(tmin=0, tmax=8, include_tmax=False) -raw_selection2 = raw_sim.copy().crop(tmin=8, tmax=19.994505956825666) - -cov_less_noisy = make_ad_hoc_cov( - (raw_selection1.copy().pick(make_these_noisy).info), std=dict(eeg=0.0000008) -) -add_noise( - raw_selection1, cov_less_noisy, iir_filter=[0.2, -0.2, 0.04], random_state=rng -) -raw_selection1.append([raw_selection2]) -raw_selection1.set_annotations(None) -raw_sim = raw_selection1 +raw_sim.drop_channels(make_these_noisy) +raw_noisy_chs.add_channels([raw_sim], force_update_info=True) +raw_sim = raw_noisy_chs # MAKE BRIDGED CHANNELS AND 1 FLAT CHANNEL data = raw_sim.get_data() # ch x times data[52, :] = data[53, :] # duplicate ch 53 and 54 -# Make the last channel random. save for later use -min_val = data[23, :].min() -max_val = data[23, :].min() + 0.0000065 -low_correlated_ch = np.random.uniform(low=min_val, high=max_val, size=len(data[23, :])) - # MAKE AN UNCORRELATED CH -data[23] = low_correlated_ch -# Shuffle it Again. -np.random.shuffle(data[23]) +min_val = data[28, :].min() +max_val = data[28, :].min() + 0.0000065 +low_correlated_ch = np.random.uniform(low=min_val, high=max_val, size=len(data[28, :])) +data[28] = low_correlated_ch +# Shuffle it Again. in-place +# np.random.shuffle(data[23]) # Make new raw out of data raw_sim = mne.io.RawArray(data, info) @@ -131,36 +125,10 @@ def data_fun(times): # Re-set the montage raw_sim.set_montage(montage) - # LOAD DEFAULT CONFIG config = ll.config.Config() config.load_default() -config = ll.config.Config() -config.load_default() - -# CUSTOMIZE CONFIG -config["ch_ch_sd"]["outliers_kwargs"]["k"] = 3 -config["ch_ch_sd"]["outliers_kwargs"]["lower"] = 0.15 -config["ch_ch_sd"]["outliers_kwargs"]["upper"] = 0.85 - -config["epoch_ch_sd"]["outliers_kwargs"]["k"] = 3 -config["epoch_ch_sd"]["outliers_kwargs"]["lower"] = 0.15 -config["epoch_ch_sd"]["outliers_kwargs"]["upper"] = 0.85 - -config["ch_low_r"]["outliers_kwargs"]["k"] = 2 -config["ch_low_r"]["outliers_kwargs"]["lower"] = 0.23 -config["ch_low_r"]["outliers_kwargs"]["upper"] = 0.85 -config["ch_low_r"]["flag_crit"] = 0.25 - -config["epoch_low_r"]["outliers_kwargs"]["k"] = 3 -config["epoch_low_r"]["outliers_kwargs"]["lower"] = 0.15 -config["epoch_low_r"]["outliers_kwargs"]["upper"] = 0.85 - -config.save("project_ll_config_face13_egi.yaml") - -pipeline = ll.LosslessPipeline("project_ll_config_face13_egi.yaml") config.save("sample_audvis_config.yaml") - # GENERATE PIPELINE pipeline = ll.LosslessPipeline("sample_audvis_config.yaml") pipeline.raw = raw_sim @@ -176,15 +144,15 @@ def test_simulated_raw(pipeline): # FIND NOISY EPOCHS pipeline.flag_ch_sd_epoch() # Epoch 2 was made noisy and should be flagged. - assert np.array_equal(pipeline.flags["epoch"]["ch_sd"], [2]) + assert_array_equal(pipeline.flags["epoch"]["ch_sd"], [2]) epochs = pipeline.get_epochs() # only epoch at indice 2 should have been dropped assert all(not tup or i == 2 for i, tup in enumerate(epochs.drop_log)) # RUN FLAG_CH_SD pipeline.flag_ch_sd_ch() - noisy_chs = ["EEG 001", "EEG 003", "EEG 005", "EEG 007"] - assert np.array_equal(pipeline.flags["ch"]["ch_sd"], noisy_chs) + noisy_chs = ["EEG 001", "EEG 002", "EEG 007"] + assert_array_equal(pipeline.flags["ch"]["ch_sd"], noisy_chs) # FIND UNCORRELATED CHS data_r_ch = pipeline.flag_ch_low_r() @@ -193,10 +161,9 @@ def test_simulated_raw(pipeline): [name not in data_r_ch.coords["ch"] for name in pipeline.flags["ch"]["ch_sd"]] ) # EEG 024 was made random and should be flagged. - assert ["EEG 024"] in pipeline.flags["ch"]["low_r"] + # https://github.com/lina-usc/pylossless/issues/141 # RUN FLAG_CH_BRIDGE - data_r_ch = pipeline.flag_ch_low_r() pipeline.flag_ch_bridge(data_r_ch) # Channels below are duplicates and should be flagged. assert "EEG 053" in pipeline.flags["ch"]["bridge"] From 18175122d8a950856399ef9fe65ce886dfdc968a Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 1 Sep 2023 15:03:54 -0400 Subject: [PATCH 06/17] TST: error in linting setup and deprecation warning there is deprecation warning stemming from dash-testing that they need to handle. --- .github/workflows/check_linting.yml | 27 ++++++++--------- .github/workflows/test_pipeline.yml | 46 ++++++++++++++++------------- pyproject.toml | 4 ++- 3 files changed, 42 insertions(+), 35 deletions(-) diff --git a/.github/workflows/check_linting.yml b/.github/workflows/check_linting.yml index 4634d9f..743e9e9 100644 --- a/.github/workflows/check_linting.yml +++ b/.github/workflows/check_linting.yml @@ -7,17 +7,16 @@ jobs: name: Style runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - uses: actions/setup-python@v4 - with: - python-version: '3.11' - - uses: psf/black@stable - - uses: pre-commit/action@v3.0.0 - - name: Install dependencies - run: pip install -r requirements_testing.txt - # Run Ruff - - name: Run Ruff - run: ruff pylossless - # Run Codespell - - name: Run Codespell - run: codespell pylossless docs + - uses: actions/checkout@v3 + with: + python-version: "3.11" + - uses: psf/black@stable + - uses: pre-commit/action@v3.0.0 + - name: Install dependencies + run: pip install -r requirements_testing.txt + # Run Ruff + - name: Run Ruff + run: ruff pylossless + # Run Codespell + - name: Run Codespell + run: codespell pylossless docs diff --git a/.github/workflows/test_pipeline.yml b/.github/workflows/test_pipeline.yml index 1eee679..c3552d2 100644 --- a/.github/workflows/test_pipeline.yml +++ b/.github/workflows/test_pipeline.yml @@ -1,28 +1,34 @@ name: Test pipeline -on: pull_request +on: + push: + branches: + - main + pull_request: + branches: + - main jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.x - uses: actions/setup-python@v4 - with: - python-version: "3.x" - - name: Install Pylossless & Deps - run: pip install -e . - - name: Install testing dependencies - run: pip install -r requirements_testing.txt - - name: Install QC depedencies - run: pip install -r requirements_qc.txt - - name: install openneuro - run: pip install openneuro-py - - name: Test Pipeline - run: | - coverage run -m pytest - - name: Upload coverage to codecov - uses: codecov/codecov-action@v3 - with: + - uses: actions/checkout@v3 + - name: Set up Python 3.x + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install Pylossless & Deps + run: pip install -e . + - name: Install testing dependencies + run: pip install -r requirements_testing.txt + - name: Install QC depedencies + run: pip install -r requirements_qc.txt + - name: install openneuro + run: pip install openneuro-py + - name: Test Pipeline + run: | + coverage run -m pytest + - name: Upload coverage to codecov + uses: codecov/codecov-action@v3 + with: token: ${{secrets.CODECOV_TOKEN}} diff --git a/pyproject.toml b/pyproject.toml index f0e14d4..2010250 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,4 +23,6 @@ exclude = "(dist/)|(build/)|(.*\\.ipynb)" # Exclude build artifacts and notebook [tool.pytest.ini_options] filterwarnings = ["error", # error on warning except the non-int sample freq warning, which we want to be raised -'ignore:The Raw sampling frequency is',] \ No newline at end of file +'ignore:The Raw sampling frequency is', + # deprecation in dash-testing that needs to be reported +'HTTPResponse.getheader() is deprecated',] \ No newline at end of file From b9627f56fbb942eef3187c0c876ea80436a3ce5f Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 1 Sep 2023 15:13:57 -0400 Subject: [PATCH 07/17] TST: fix pytest config trying to ignore the dash testing deprecation warning for now --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2010250..1e35962 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,4 +25,4 @@ filterwarnings = ["error", # error on warning except the non-int sample freq warning, which we want to be raised 'ignore:The Raw sampling frequency is', # deprecation in dash-testing that needs to be reported -'HTTPResponse.getheader() is deprecated',] \ No newline at end of file +"ignore::DeprecationWarning",] \ No newline at end of file From 91a65186870735276dffe915a6be2edd200e4b5d Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 12:03:43 -0400 Subject: [PATCH 08/17] FIX, TST: Suppress warning message in read_raw_bids --- pylossless/tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index 90362cc..1bb17c3 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -53,7 +53,7 @@ def load_openneuro_bids(): while not bids_path.fpath.with_suffix(".bdf").exists(): print(list(bids_path.fpath.glob("*"))) sleep(1) - raw = mne_bids.read_raw_bids(bids_path) + raw = mne_bids.read_raw_bids(bids_path, verbose="ERROR") annots = mne.Annotations( onset=[1, 15], duration=[1, 1], description=["test_annot", "test_annot"] ) From 5fa69c45e9d536f163cb82fe9e562158490b5f4b Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 12:19:41 -0400 Subject: [PATCH 09/17] TST: use macos runner instead of ubuntu right now we only test on one OS. In principle we should test with ubuntu, macos, and windows. But lets start with mac so that the testing runner matches the OS that the local devs use --- .github/workflows/test_pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_pipeline.yml b/.github/workflows/test_pipeline.yml index c3552d2..6178ac1 100644 --- a/.github/workflows/test_pipeline.yml +++ b/.github/workflows/test_pipeline.yml @@ -10,7 +10,7 @@ on: jobs: test: - runs-on: ubuntu-latest + runs-on: macos-latest steps: - uses: actions/checkout@v3 - name: Set up Python 3.x From 5ca3c73499195c7b09cc8237f73da71aff016825 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 13:07:34 -0400 Subject: [PATCH 10/17] TST: adjust config kwargs in test_simulated really only EEG 001 and EEG 002 should be flagged because we make them noisy --- pylossless/tests/test_simulated.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pylossless/tests/test_simulated.py b/pylossless/tests/test_simulated.py index f955e99..d5a747c 100644 --- a/pylossless/tests/test_simulated.py +++ b/pylossless/tests/test_simulated.py @@ -120,14 +120,11 @@ def data_fun(times): # Re-set the montage raw_sim.set_montage(montage) -# Make new raw out of data -raw_sim = mne.io.RawArray(data, info) -# Re-set the montage -raw_sim.set_montage(montage) - # LOAD DEFAULT CONFIG config = ll.config.Config() config.load_default() +config["ch_ch_sd"]["outliers_kwargs"]["lower"] = 0.25 +config["ch_ch_sd"]["outliers_kwargs"]["upper"] = 0.75 config.save("sample_audvis_config.yaml") # GENERATE PIPELINE pipeline = ll.LosslessPipeline("sample_audvis_config.yaml") @@ -151,7 +148,7 @@ def test_simulated_raw(pipeline): # RUN FLAG_CH_SD pipeline.flag_ch_sd_ch() - noisy_chs = ["EEG 001", "EEG 002", "EEG 007"] + noisy_chs = ["EEG 001", "EEG 002"] assert_array_equal(pipeline.flags["ch"]["ch_sd"], noisy_chs) # FIND UNCORRELATED CHS From 6a4811f5add78070c0bc94fe0cfb7eb41896f575 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 13:25:38 -0400 Subject: [PATCH 11/17] TST, FIX: mark test_topoViz as failing - there is some error in test_topoViz that needs to be looked into in separate PR --- pylossless/dash/tests/test_topo_viz.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pylossless/dash/tests/test_topo_viz.py b/pylossless/dash/tests/test_topo_viz.py index 7cce43b..abba694 100644 --- a/pylossless/dash/tests/test_topo_viz.py +++ b/pylossless/dash/tests/test_topo_viz.py @@ -4,6 +4,8 @@ """Tests for topo_viz.py.""" +import pytest + import mne from dash import html @@ -66,6 +68,7 @@ def test_GridTopoPlot(): # chromedriver: https://chromedriver.storage.googleapis.com/ # index.html?path=114.0.5735.90/ +@pytest.marks.xfail def test_TopoViz(dash_duo): """Test TopoViz.""" raw, ica = get_raw_ica() From 5a18ff88466c98e8c0990111adb4431be074179d Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 14:27:46 -0400 Subject: [PATCH 12/17] TST: crop bids raw to 60 secs for slight speedup --- pylossless/tests/test_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pylossless/tests/test_pipeline.py b/pylossless/tests/test_pipeline.py index 1bb17c3..612852a 100644 --- a/pylossless/tests/test_pipeline.py +++ b/pylossless/tests/test_pipeline.py @@ -69,6 +69,7 @@ def test_pipeline_run(dataset, find_breaks): """Test running the pipeline.""" if dataset == "openneuro": raw, config, bids_root = load_openneuro_bids() + raw.crop(tmin=0, tmax=60) # take 60 seconds for speed if find_breaks: config["find_breaks"] = {} From afc29ba9a76c8dc00aace30ff3a1a49e67fbe7b5 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 14:28:14 -0400 Subject: [PATCH 13/17] FIX, TST: module is pytest mark not pytest marks --- pylossless/dash/tests/test_topo_viz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylossless/dash/tests/test_topo_viz.py b/pylossless/dash/tests/test_topo_viz.py index abba694..237f5f2 100644 --- a/pylossless/dash/tests/test_topo_viz.py +++ b/pylossless/dash/tests/test_topo_viz.py @@ -68,7 +68,7 @@ def test_GridTopoPlot(): # chromedriver: https://chromedriver.storage.googleapis.com/ # index.html?path=114.0.5735.90/ -@pytest.marks.xfail +@pytest.mark.xfail(reason="an issue with chromedriver on GH CI to be debugged") def test_TopoViz(dash_duo): """Test TopoViz.""" raw, ica = get_raw_ica() From b997a0a672e9b886be4b415d3e1e3063e44fd1c3 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Thu, 14 Sep 2023 14:49:51 -0400 Subject: [PATCH 14/17] TST: raise flag_crit in test_simulated the test_simulated file is very short, approx 20 seconds. So a few epochs with blinks will be outliers and accont for more than 20 percent of total epochs, causing the channel to be flagged. Raise the flag_crit to 30 percent to be more liberal so that blinks in a channel to cause it to be flagged --- pylossless/tests/test_simulated.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pylossless/tests/test_simulated.py b/pylossless/tests/test_simulated.py index d5a747c..44d552a 100644 --- a/pylossless/tests/test_simulated.py +++ b/pylossless/tests/test_simulated.py @@ -125,6 +125,8 @@ def data_fun(times): config.load_default() config["ch_ch_sd"]["outliers_kwargs"]["lower"] = 0.25 config["ch_ch_sd"]["outliers_kwargs"]["upper"] = 0.75 +# short file, raise threshold so epochs w/ blinks dont cause flag +config["ch_ch_sd"]["flag_crit"] = 0.30 config.save("sample_audvis_config.yaml") # GENERATE PIPELINE pipeline = ll.LosslessPipeline("sample_audvis_config.yaml") From cbeef410c21021728be1f441caf701c401bb6636 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 15 Sep 2023 09:38:34 -0400 Subject: [PATCH 15/17] DOC: Add instructions for installing pre commit hook and testing deps --- docs/source/install.rst | 43 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/docs/source/install.rst b/docs/source/install.rst index a8e8bae..f527047 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -4,15 +4,26 @@ Installation ============ -**************************************** -Install via :code:`pip` or :code:`conda` -**************************************** +To stay up to date with the latest version of pyLossless, we recommend that you install +the package from the github repository. This will allow you to easily update to the +latest version of pyLossless as we continue to develop it. .. hint:: To use pyLossless you need to have the ``git`` command line tool installed. If you are not sure, see this `tutorial - `__ + `__ + + +Once you have git installed and configured, and before creating your local copy +of the codebase, go to the `PyLossless GitHub `_ +page and create a +`fork `_ into your GitHub +user account. + +**************************************** +Install via :code:`pip` or :code:`conda` +**************************************** Pylossless requires Python version |min_python_version| or higher. If you need to install Python, please see `MNE-Pythons guide to installing Python @@ -63,4 +74,26 @@ or via :code:`conda`: $ conda develop ./pylossless -That's it! You are now ready to use pyLossless. \ No newline at end of file +That's it! You are now ready to use pyLossless. + +Additional Requirements for Development +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you plan on contributing to the development of pyLossless, you will need to install +some additional dependencies so that you can run tests and build the documentation +locally. The code below will install the additional dependencies as well as the +pre-commit hooks that we use to ensure that all code is formatted correctly. Make sure +that you have activated your ``pylossless`` environment and are inside the pylossless +git repository directory, before running the code below: + +.. code-block:: console + + $ pip install -r requirements_testing.txt + $ pip install -r docs/requirements_doc.txt + $ pre-commit run -a + +PyLossless uses `black `_ style formatting. If you are +using Visual Studio Code, you can also install the black extension to automatically +format your code. See the instrucitons at this +`link +`_ From 1597e36932db01187cf6f8aff5cd17957b0fba63 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 15 Sep 2023 09:39:05 -0400 Subject: [PATCH 16/17] DOC, FIX: fixed instructions to build docs locally to open the built docs, path should end with .html not .rst --- docs/source/contributing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 0793e2d..a2e3320 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -113,7 +113,7 @@ Viewing the docs locally The built documentation is placed in ``docs/build``. You should not change any files in this directory. If you want to view the documentation -locally, simply click on the ``docs/build/html/index.rst`` file from your +locally, simply click on the ``docs/build/html/index.html`` file from your file browser or open it with the command line: If you are in the ``docs`` directory: From c257c0ed3b6ce977378135c87b9fcf0e20b75db6 Mon Sep 17 00:00:00 2001 From: Scott Huberty Date: Fri, 15 Sep 2023 10:00:32 -0400 Subject: [PATCH 17/17] FIX: Typo in install rest doc --- docs/source/install.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/install.rst b/docs/source/install.rst index f527047..99d05eb 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -94,6 +94,6 @@ git repository directory, before running the code below: PyLossless uses `black `_ style formatting. If you are using Visual Studio Code, you can also install the black extension to automatically -format your code. See the instrucitons at this +format your code. See the instructions at this `link `_