Skip to content

Commit

Permalink
BUG: Move defusedxml to optional dependencies (#12264)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Dec 5, 2023
1 parent b876edb commit 2a87af7
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 23 deletions.
1 change: 1 addition & 0 deletions doc/changes/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Version 1.6.1 (unreleased)
--------------------------

- Fix bug with type hints in :func:`mne.io.read_raw_neuralynx` (:gh:`12236` by `Richard Höchenberger`_)
- ``defusedxml`` is now an optional (rather than required) dependency and needed when reading EGI-MFF data, NEDF data, and BrainVision montages (:gh:`12264` by `Eric Larson`_)

.. _changes_1_6_0:

Expand Down
11 changes: 5 additions & 6 deletions mne/channels/_dig_montage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
# Copyright the MNE-Python contributors.

import numpy as np
from defusedxml import ElementTree

from ..utils import Bunch, _check_fname, warn
from ..utils import Bunch, _check_fname, _soft_import, warn


def _read_dig_montage_egi(
Expand All @@ -28,8 +27,8 @@ def _read_dig_montage_egi(
"hsp, hpi, elp, point_names, fif must all be " "None if egi is not None"
)
_check_fname(fname, overwrite="read", must_exist=True)

root = ElementTree.parse(fname).getroot()
defusedxml = _soft_import("defusedxml", "reading EGI montages")
root = defusedxml.ElementTree.parse(fname).getroot()
ns = root.tag[root.tag.index("{") : root.tag.index("}") + 1]
sensors = root.find("%ssensorLayout/%ssensors" % (ns, ns))
fids = dict()
Expand Down Expand Up @@ -76,8 +75,8 @@ def _read_dig_montage_egi(

def _parse_brainvision_dig_montage(fname, scale):
FID_NAME_MAP = {"Nasion": "nasion", "RPA": "rpa", "LPA": "lpa"}

root = ElementTree.parse(fname).getroot()
defusedxml = _soft_import("defusedxml", "reading BrainVision montages")
root = defusedxml.ElementTree.parse(fname).getroot()
sensors = root.find("CapTrakElectrodeList")

fids, dig_ch_pos = dict(), dict()
Expand Down
6 changes: 3 additions & 3 deletions mne/channels/_standard_montage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from functools import partial

import numpy as np
from defusedxml import ElementTree

from .._freesurfer import get_mni_fiducials
from ..transforms import _sph_to_cart
from ..utils import _pl, warn
from ..utils import _pl, _soft_import, warn
from . import __file__ as _CHANNELS_INIT_FILE
from .montage import make_dig_montage

Expand Down Expand Up @@ -344,7 +343,8 @@ def _read_brainvision(fname, head_size):
# standard electrode positions: X-axis from T7 to T8, Y-axis from Oz to
# Fpz, Z-axis orthogonal from XY-plane through Cz, fit to a sphere if
# idealized (when radius=1), specified in millimeters
root = ElementTree.parse(fname).getroot()
defusedxml = _soft_import("defusedxml", "reading BrainVision montages")
root = defusedxml.ElementTree.parse(fname).getroot()
ch_names = [s.text for s in root.findall("./Electrode/Name")]
theta = [float(s.text) for s in root.findall("./Electrode/Theta")]
pol = np.deg2rad(np.array(theta))
Expand Down
7 changes: 5 additions & 2 deletions mne/channels/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ def test_documented():
)
def test_montage_readers(reader, file_content, expected_dig, ext, warning, tmp_path):
"""Test that we have an equivalent of read_montage for all file formats."""
if file_content.startswith("<?xml"):
pytest.importorskip("defusedxml")
fname = tmp_path / f"test.{ext}"
with open(fname, "w") as fid:
fid.write(file_content)
Expand Down Expand Up @@ -1067,6 +1069,7 @@ def test_fif_dig_montage(tmp_path):
@testing.requires_testing_data
def test_egi_dig_montage(tmp_path):
"""Test EGI MFF XML dig montage support."""
pytest.importorskip("defusedxml")
dig_montage = read_dig_egi(egi_dig_montage_fname)
fid, coord = _get_fid_coords(dig_montage.dig)

Expand Down Expand Up @@ -1123,6 +1126,7 @@ def _pop_montage(dig_montage, ch_name):
@testing.requires_testing_data
def test_read_dig_captrak(tmp_path):
"""Test reading a captrak montage file."""
pytest.importorskip("defusedxml")
EXPECTED_CH_NAMES_OLD = [
"AF3",
"AF4",
Expand Down Expand Up @@ -1933,13 +1937,12 @@ def test_get_builtin_montages():
def test_plot_montage():
"""Test plotting montage."""
# gh-8025
pytest.importorskip("defusedxml")
montage = read_dig_captrak(bvct_dig_montage_fname)
montage.plot()
plt.close("all")

f, ax = plt.subplots(1, 1)
montage.plot(axes=ax)
plt.close("all")

with pytest.raises(TypeError, match="must be an instance of Axes"):
montage.plot(axes=101)
Expand Down
2 changes: 2 additions & 0 deletions mne/export/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def test_export_epochs_eeglab(tmp_path, preload):
def test_export_evokeds_to_mff(tmp_path, fmt, do_history):
"""Test exporting evoked dataset to MFF."""
pytest.importorskip("mffpy", "0.5.7")
pytest.importorskip("defusedxml")
evoked = read_evokeds_mff(egi_evoked_fname)
export_fname = tmp_path / "evoked.mff"
history = [
Expand Down Expand Up @@ -515,6 +516,7 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history):
def test_export_to_mff_no_device():
"""Test no device type throws ValueError."""
pytest.importorskip("mffpy", "0.5.7")
pytest.importorskip("defusedxml")
evoked = read_evokeds_mff(egi_evoked_fname, condition="Category 1")
evoked.info["device_info"] = None
with pytest.raises(ValueError, match="No device type."):
Expand Down
9 changes: 7 additions & 2 deletions mne/io/egi/egimff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path

import numpy as np
from defusedxml.minidom import parse

from ..._fiff.constants import FIFF
from ..._fiff.meas_info import _empty_info, _ensure_meas_date_none_or_dt, create_info
Expand All @@ -19,7 +18,7 @@
from ...annotations import Annotations
from ...channels.montage import make_dig_montage
from ...evoked import EvokedArray
from ...utils import _check_fname, _check_option, logger, verbose, warn
from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn
from ..base import BaseRaw
from .events import _combine_triggers, _read_events
from .general import (
Expand All @@ -36,6 +35,9 @@

def _read_mff_header(filepath):
"""Read mff header."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

all_files = _get_signalfname(filepath)
eeg_file = all_files["EEG"]["signal"]
eeg_info_file = all_files["EEG"]["info"]
Expand Down Expand Up @@ -289,6 +291,9 @@ def _get_eeg_calibration_info(filepath, egi_info):

def _read_locs(filepath, egi_info, channel_naming):
"""Read channel locations."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

fname = op.join(filepath, "coordinates.xml")
if not op.exists(fname):
logger.warn("File coordinates.xml not found, not setting channel locations")
Expand Down
6 changes: 3 additions & 3 deletions mne/io/egi/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from os.path import basename, join, splitext

import numpy as np
from defusedxml.ElementTree import parse

from ...utils import logger
from ...utils import _soft_import, logger


def _read_events(input_fname, info):
Expand Down Expand Up @@ -82,7 +81,8 @@ def _read_mff_events(filename, sfreq):

def _parse_xml(xml_file):
"""Parse XML file."""
xml = parse(xml_file)
defusedxml = _soft_import("defusedxml", "reading EGI MFF data")
xml = defusedxml.ElementTree.parse(xml_file)
root = xml.getroot()
return _xml2list(root)

Expand Down
15 changes: 13 additions & 2 deletions mne/io/egi/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import re

import numpy as np
from defusedxml.minidom import parse

from ...utils import _pl
from ...utils import _pl, _soft_import


def _extract(tags, filepath=None, obj=None):
"""Extract info from XML."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

if obj is not None:
fileobj = obj
elif filepath is not None:
Expand All @@ -30,6 +32,9 @@ def _extract(tags, filepath=None, obj=None):

def _get_gains(filepath):
"""Parse gains."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

file_obj = parse(filepath)
objects = file_obj.getElementsByTagName("calibration")
gains = dict()
Expand All @@ -46,6 +51,9 @@ def _get_gains(filepath):

def _get_ep_info(filepath):
"""Get epoch info."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

epochfile = filepath + "/epochs.xml"
epochlist = parse(epochfile)
epochs = epochlist.getElementsByTagName("epoch")
Expand Down Expand Up @@ -123,6 +131,9 @@ def _get_blocks(filepath):

def _get_signalfname(filepath):
"""Get filenames."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

listfiles = os.listdir(filepath)
binfiles = list(
f for f in listfiles if "signal" in f and f[-4:] == ".bin" and f[0] != "."
Expand Down
11 changes: 11 additions & 0 deletions mne/io/egi/tests/test_egi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
)
def test_egi_mff_pause(fname, skip_times, event_times):
"""Test EGI MFF with pauses."""
pytest.importorskip("defusedxml")
if fname == egi_pause_w1337_fname:
# too slow to _test_raw_reader
raw = read_raw_egi(fname).load_data()
Expand Down Expand Up @@ -129,6 +130,7 @@ def test_egi_mff_pause(fname, skip_times, event_times):
)
def test_egi_mff_pause_chunks(fname, tmp_path):
"""Test that on-demand of all short segments works (via I/O)."""
pytest.importorskip("defusedxml")
fname_temp = tmp_path / "test_raw.fif"
raw_data = read_raw_egi(fname, preload=True).get_data()
raw = read_raw_egi(fname)
Expand All @@ -142,6 +144,7 @@ def test_egi_mff_pause_chunks(fname, tmp_path):
@requires_testing_data
def test_io_egi_mff():
"""Test importing EGI MFF simple binary files."""
pytest.importorskip("defusedxml")
# want vars for n chans
n_ref = 1
n_eeg = 128
Expand Down Expand Up @@ -258,6 +261,7 @@ def test_io_egi():
@requires_testing_data
def test_io_egi_pns_mff(tmp_path):
"""Test importing EGI MFF with PNS data."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(egi_mff_pns_fname, include=None, preload=True, verbose="error")
assert "RawMff" in repr(raw)
pns_chans = pick_types(raw.info, ecg=True, bio=True, emg=True)
Expand Down Expand Up @@ -314,6 +318,7 @@ def test_io_egi_pns_mff(tmp_path):
@pytest.mark.parametrize("preload", (True, False))
def test_io_egi_pns_mff_bug(preload):
"""Test importing EGI MFF with PNS data (BUG)."""
pytest.importorskip("defusedxml")
egi_fname_mff = testing_path / "EGI" / "test_egi_pns_bug.mff"
with pytest.warns(RuntimeWarning, match="EGI PSG sample bug"):
raw = read_raw_egi(
Expand Down Expand Up @@ -356,6 +361,7 @@ def test_io_egi_pns_mff_bug(preload):
@requires_testing_data
def test_io_egi_crop_no_preload():
"""Test crop non-preloaded EGI MFF data (BUG)."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(egi_mff_fname, preload=False)
raw.crop(17.5, 20.5)
raw.load_data()
Expand Down Expand Up @@ -383,6 +389,8 @@ def test_io_egi_crop_no_preload():
def test_io_egi_evokeds_mff(idx, cond, tmax, signals, bads):
"""Test reading evoked MFF file."""
pytest.importorskip("mffpy", "0.5.7")

pytest.importorskip("defusedxml")
# expected n channels
n_eeg = 256
n_ref = 1
Expand Down Expand Up @@ -468,6 +476,7 @@ def test_read_evokeds_mff_bad_input():
@requires_testing_data
def test_egi_coord_frame():
"""Test that EGI coordinate frame is changed to head."""
pytest.importorskip("defusedxml")
info = read_raw_egi(egi_mff_fname).info
want_idents = (
FIFF.FIFFV_POINT_LPA,
Expand Down Expand Up @@ -505,6 +514,7 @@ def test_egi_coord_frame():
)
def test_meas_date(fname, timestamp, utc_offset):
"""Test meas date conversion."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(fname, verbose="warning")
dt = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z")
measdate = dt.astimezone(timezone.utc)
Expand All @@ -526,6 +536,7 @@ def test_meas_date(fname, timestamp, utc_offset):
)
def test_set_standard_montage_mff(fname, standard_montage):
"""Test setting a standard montage."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(fname, verbose="warning")
n_eeg = int(standard_montage.split("-")[-1])
n_dig = n_eeg + 3
Expand Down
6 changes: 3 additions & 3 deletions mne/io/nedf/nedf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from datetime import datetime, timezone

import numpy as np
from defusedxml import ElementTree

from ..._fiff.meas_info import create_info
from ..._fiff.utils import _mult_cal_one
from ...utils import _check_fname, verbose, warn
from ...utils import _check_fname, _soft_import, verbose, warn
from ..base import BaseRaw


Expand Down Expand Up @@ -52,6 +51,7 @@ def _parse_nedf_header(header):
n_samples : int
The number of data samples.
"""
defusedxml = _soft_import("defusedxml", "reading NEDF data")
info = {}
# nedf files have three accelerometer channels sampled at 100Hz followed
# by five EEG samples + TTL trigger sampled at 500Hz
Expand All @@ -69,7 +69,7 @@ def _parse_nedf_header(header):
headerend = header.find(b"\0")
if headerend == -1:
raise RuntimeError("End of header null not found")
headerxml = ElementTree.fromstring(header[:headerend])
headerxml = defusedxml.ElementTree.fromstring(header[:headerend])
nedfversion = headerxml.findtext("NEDFversion", "")
if nedfversion not in ["1.3", "1.4"]:
warn("NEDFversion unsupported, use with caution")
Expand Down
2 changes: 2 additions & 0 deletions mne/io/nedf/tests/test_nedf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
<STIMSettings/>
</nedf>\x00"""

pytest.importorskip("defusedxml")


@pytest.mark.parametrize("nacc", (0, 3))
def test_nedf_header_parser(nacc):
Expand Down
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ dependencies = [
"jinja2",
"importlib_resources>=5.10.2; python_version<'3.9'",
"lazy_loader>=0.3",
"defusedxml",
]

[project.optional-dependencies]
Expand All @@ -60,7 +59,8 @@ hdf5 = ["h5io", "pymatreader"]
full = [
"mne[hdf5]",
"qtpy",
"PyQt6",
"PyQt6!=6.6.1",
"PyQt6-Qt6!=6.6.1",
"pyobjc-framework-Cocoa>=5.2.0; platform_system=='Darwin'",
"sip",
"scikit-learn",
Expand Down Expand Up @@ -95,6 +95,12 @@ full = [
"darkdetect",
"qdarkstyle",
"threadpoolctl",
# duplicated in test_extra:
"eeglabio",
"EDFlib-Python",
"pybv",
"snirf",
"defusedxml",
]

# Dependencies for running the test infrastructure
Expand Down

0 comments on commit 2a87af7

Please sign in to comment.