Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Move defusedxml to optional dependencies #12264

Merged
merged 6 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Bugs
~~~~
- Allow :func:`mne.viz.plot_compare_evokeds` to plot eyetracking channels, and improve error handling (:gh:`12190` by `Scott Huberty`_)
- Fix bug with accessing the last data sample using ``raw[:, -1]`` where an empty array was returned (:gh:`12248` by `Eric Larson`_)
- ``defusedxml`` is now an optional (rather than required) dependency and needed when reading EGI-MFF data, NEDF data, and BrainVision montages (:gh:`12264` by `Eric Larson`_)
- Fix bug with type hints in :func:`mne.io.read_raw_neuralynx` (:gh:`12236` by `Richard Höchenberger`_)

API changes
Expand Down
11 changes: 5 additions & 6 deletions mne/channels/_dig_montage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
# Copyright the MNE-Python contributors.

import numpy as np
from defusedxml import ElementTree

from ..utils import Bunch, _check_fname, warn
from ..utils import Bunch, _check_fname, _soft_import, warn


def _read_dig_montage_egi(
Expand All @@ -28,8 +27,8 @@ def _read_dig_montage_egi(
"hsp, hpi, elp, point_names, fif must all be " "None if egi is not None"
)
_check_fname(fname, overwrite="read", must_exist=True)

root = ElementTree.parse(fname).getroot()
defusedxml = _soft_import("defusedxml", "reading EGI montages")
root = defusedxml.ElementTree.parse(fname).getroot()
ns = root.tag[root.tag.index("{") : root.tag.index("}") + 1]
sensors = root.find("%ssensorLayout/%ssensors" % (ns, ns))
fids = dict()
Expand Down Expand Up @@ -76,8 +75,8 @@ def _read_dig_montage_egi(

def _parse_brainvision_dig_montage(fname, scale):
FID_NAME_MAP = {"Nasion": "nasion", "RPA": "rpa", "LPA": "lpa"}

root = ElementTree.parse(fname).getroot()
defusedxml = _soft_import("defusedxml", "reading BrainVision montages")
root = defusedxml.ElementTree.parse(fname).getroot()
sensors = root.find("CapTrakElectrodeList")

fids, dig_ch_pos = dict(), dict()
Expand Down
6 changes: 3 additions & 3 deletions mne/channels/_standard_montage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from functools import partial

import numpy as np
from defusedxml import ElementTree

from .._freesurfer import get_mni_fiducials
from ..transforms import _sph_to_cart
from ..utils import _pl, warn
from ..utils import _pl, _soft_import, warn
from . import __file__ as _CHANNELS_INIT_FILE
from .montage import make_dig_montage

Expand Down Expand Up @@ -344,7 +343,8 @@ def _read_brainvision(fname, head_size):
# standard electrode positions: X-axis from T7 to T8, Y-axis from Oz to
# Fpz, Z-axis orthogonal from XY-plane through Cz, fit to a sphere if
# idealized (when radius=1), specified in millimeters
root = ElementTree.parse(fname).getroot()
defusedxml = _soft_import("defusedxml", "reading BrainVision montages")
root = defusedxml.ElementTree.parse(fname).getroot()
ch_names = [s.text for s in root.findall("./Electrode/Name")]
theta = [float(s.text) for s in root.findall("./Electrode/Theta")]
pol = np.deg2rad(np.array(theta))
Expand Down
7 changes: 5 additions & 2 deletions mne/channels/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ def test_documented():
)
def test_montage_readers(reader, file_content, expected_dig, ext, warning, tmp_path):
"""Test that we have an equivalent of read_montage for all file formats."""
if file_content.startswith("<?xml"):
pytest.importorskip("defusedxml")
fname = tmp_path / f"test.{ext}"
with open(fname, "w") as fid:
fid.write(file_content)
Expand Down Expand Up @@ -1067,6 +1069,7 @@ def test_fif_dig_montage(tmp_path):
@testing.requires_testing_data
def test_egi_dig_montage(tmp_path):
"""Test EGI MFF XML dig montage support."""
pytest.importorskip("defusedxml")
dig_montage = read_dig_egi(egi_dig_montage_fname)
fid, coord = _get_fid_coords(dig_montage.dig)

Expand Down Expand Up @@ -1123,6 +1126,7 @@ def _pop_montage(dig_montage, ch_name):
@testing.requires_testing_data
def test_read_dig_captrak(tmp_path):
"""Test reading a captrak montage file."""
pytest.importorskip("defusedxml")
EXPECTED_CH_NAMES_OLD = [
"AF3",
"AF4",
Expand Down Expand Up @@ -1933,13 +1937,12 @@ def test_get_builtin_montages():
def test_plot_montage():
"""Test plotting montage."""
# gh-8025
pytest.importorskip("defusedxml")
montage = read_dig_captrak(bvct_dig_montage_fname)
montage.plot()
plt.close("all")

f, ax = plt.subplots(1, 1)
montage.plot(axes=ax)
plt.close("all")

with pytest.raises(TypeError, match="must be an instance of Axes"):
montage.plot(axes=101)
Expand Down
2 changes: 2 additions & 0 deletions mne/export/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def test_export_epochs_eeglab(tmp_path, preload):
def test_export_evokeds_to_mff(tmp_path, fmt, do_history):
"""Test exporting evoked dataset to MFF."""
pytest.importorskip("mffpy", "0.5.7")
pytest.importorskip("defusedxml")
evoked = read_evokeds_mff(egi_evoked_fname)
export_fname = tmp_path / "evoked.mff"
history = [
Expand Down Expand Up @@ -486,6 +487,7 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history):
def test_export_to_mff_no_device():
"""Test no device type throws ValueError."""
pytest.importorskip("mffpy", "0.5.7")
pytest.importorskip("defusedxml")
evoked = read_evokeds_mff(egi_evoked_fname, condition="Category 1")
evoked.info["device_info"] = None
with pytest.raises(ValueError, match="No device type."):
Expand Down
9 changes: 7 additions & 2 deletions mne/io/egi/egimff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path

import numpy as np
from defusedxml.minidom import parse

from ..._fiff.constants import FIFF
from ..._fiff.meas_info import _empty_info, _ensure_meas_date_none_or_dt, create_info
Expand All @@ -19,7 +18,7 @@
from ...annotations import Annotations
from ...channels.montage import make_dig_montage
from ...evoked import EvokedArray
from ...utils import _check_fname, _check_option, logger, verbose, warn
from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn
from ..base import BaseRaw
from .events import _combine_triggers, _read_events
from .general import (
Expand All @@ -36,6 +35,9 @@

def _read_mff_header(filepath):
"""Read mff header."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

all_files = _get_signalfname(filepath)
eeg_file = all_files["EEG"]["signal"]
eeg_info_file = all_files["EEG"]["info"]
Expand Down Expand Up @@ -289,6 +291,9 @@ def _get_eeg_calibration_info(filepath, egi_info):

def _read_locs(filepath, egi_info, channel_naming):
"""Read channel locations."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

fname = op.join(filepath, "coordinates.xml")
if not op.exists(fname):
logger.warn("File coordinates.xml not found, not setting channel locations")
Expand Down
6 changes: 3 additions & 3 deletions mne/io/egi/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from os.path import basename, join, splitext

import numpy as np
from defusedxml.ElementTree import parse

from ...utils import logger
from ...utils import _soft_import, logger


def _read_events(input_fname, info):
Expand Down Expand Up @@ -82,7 +81,8 @@ def _read_mff_events(filename, sfreq):

def _parse_xml(xml_file):
"""Parse XML file."""
xml = parse(xml_file)
defusedxml = _soft_import("defusedxml", "reading EGI MFF data")
xml = defusedxml.ElementTree.parse(xml_file)
root = xml.getroot()
return _xml2list(root)

Expand Down
15 changes: 13 additions & 2 deletions mne/io/egi/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import re

import numpy as np
from defusedxml.minidom import parse

from ...utils import _pl
from ...utils import _pl, _soft_import


def _extract(tags, filepath=None, obj=None):
"""Extract info from XML."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

if obj is not None:
fileobj = obj
elif filepath is not None:
Expand All @@ -30,6 +32,9 @@ def _extract(tags, filepath=None, obj=None):

def _get_gains(filepath):
"""Parse gains."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

file_obj = parse(filepath)
objects = file_obj.getElementsByTagName("calibration")
gains = dict()
Expand All @@ -46,6 +51,9 @@ def _get_gains(filepath):

def _get_ep_info(filepath):
"""Get epoch info."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

epochfile = filepath + "/epochs.xml"
epochlist = parse(epochfile)
epochs = epochlist.getElementsByTagName("epoch")
Expand Down Expand Up @@ -123,6 +131,9 @@ def _get_blocks(filepath):

def _get_signalfname(filepath):
"""Get filenames."""
_soft_import("defusedxml", "reading EGI MFF data")
from defusedxml.minidom import parse

listfiles = os.listdir(filepath)
binfiles = list(
f for f in listfiles if "signal" in f and f[-4:] == ".bin" and f[0] != "."
Expand Down
11 changes: 11 additions & 0 deletions mne/io/egi/tests/test_egi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
)
def test_egi_mff_pause(fname, skip_times, event_times):
"""Test EGI MFF with pauses."""
pytest.importorskip("defusedxml")
if fname == egi_pause_w1337_fname:
# too slow to _test_raw_reader
raw = read_raw_egi(fname).load_data()
Expand Down Expand Up @@ -129,6 +130,7 @@ def test_egi_mff_pause(fname, skip_times, event_times):
)
def test_egi_mff_pause_chunks(fname, tmp_path):
"""Test that on-demand of all short segments works (via I/O)."""
pytest.importorskip("defusedxml")
fname_temp = tmp_path / "test_raw.fif"
raw_data = read_raw_egi(fname, preload=True).get_data()
raw = read_raw_egi(fname)
Expand All @@ -142,6 +144,7 @@ def test_egi_mff_pause_chunks(fname, tmp_path):
@requires_testing_data
def test_io_egi_mff():
"""Test importing EGI MFF simple binary files."""
pytest.importorskip("defusedxml")
# want vars for n chans
n_ref = 1
n_eeg = 128
Expand Down Expand Up @@ -258,6 +261,7 @@ def test_io_egi():
@requires_testing_data
def test_io_egi_pns_mff(tmp_path):
"""Test importing EGI MFF with PNS data."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(egi_mff_pns_fname, include=None, preload=True, verbose="error")
assert "RawMff" in repr(raw)
pns_chans = pick_types(raw.info, ecg=True, bio=True, emg=True)
Expand Down Expand Up @@ -314,6 +318,7 @@ def test_io_egi_pns_mff(tmp_path):
@pytest.mark.parametrize("preload", (True, False))
def test_io_egi_pns_mff_bug(preload):
"""Test importing EGI MFF with PNS data (BUG)."""
pytest.importorskip("defusedxml")
egi_fname_mff = testing_path / "EGI" / "test_egi_pns_bug.mff"
with pytest.warns(RuntimeWarning, match="EGI PSG sample bug"):
raw = read_raw_egi(
Expand Down Expand Up @@ -356,6 +361,7 @@ def test_io_egi_pns_mff_bug(preload):
@requires_testing_data
def test_io_egi_crop_no_preload():
"""Test crop non-preloaded EGI MFF data (BUG)."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(egi_mff_fname, preload=False)
raw.crop(17.5, 20.5)
raw.load_data()
Expand Down Expand Up @@ -383,6 +389,8 @@ def test_io_egi_crop_no_preload():
def test_io_egi_evokeds_mff(idx, cond, tmax, signals, bads):
"""Test reading evoked MFF file."""
pytest.importorskip("mffpy", "0.5.7")

pytest.importorskip("defusedxml")
# expected n channels
n_eeg = 256
n_ref = 1
Expand Down Expand Up @@ -468,6 +476,7 @@ def test_read_evokeds_mff_bad_input():
@requires_testing_data
def test_egi_coord_frame():
"""Test that EGI coordinate frame is changed to head."""
pytest.importorskip("defusedxml")
info = read_raw_egi(egi_mff_fname).info
want_idents = (
FIFF.FIFFV_POINT_LPA,
Expand Down Expand Up @@ -505,6 +514,7 @@ def test_egi_coord_frame():
)
def test_meas_date(fname, timestamp, utc_offset):
"""Test meas date conversion."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(fname, verbose="warning")
dt = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z")
measdate = dt.astimezone(timezone.utc)
Expand All @@ -526,6 +536,7 @@ def test_meas_date(fname, timestamp, utc_offset):
)
def test_set_standard_montage_mff(fname, standard_montage):
"""Test setting a standard montage."""
pytest.importorskip("defusedxml")
raw = read_raw_egi(fname, verbose="warning")
n_eeg = int(standard_montage.split("-")[-1])
n_dig = n_eeg + 3
Expand Down
6 changes: 3 additions & 3 deletions mne/io/nedf/nedf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from datetime import datetime, timezone

import numpy as np
from defusedxml import ElementTree

from ..._fiff.meas_info import create_info
from ..._fiff.utils import _mult_cal_one
from ...utils import _check_fname, verbose, warn
from ...utils import _check_fname, _soft_import, verbose, warn
from ..base import BaseRaw


Expand Down Expand Up @@ -52,6 +51,7 @@ def _parse_nedf_header(header):
n_samples : int
The number of data samples.
"""
defusedxml = _soft_import("defusedxml", "reading NEDF data")
info = {}
# nedf files have three accelerometer channels sampled at 100Hz followed
# by five EEG samples + TTL trigger sampled at 500Hz
Expand All @@ -69,7 +69,7 @@ def _parse_nedf_header(header):
headerend = header.find(b"\0")
if headerend == -1:
raise RuntimeError("End of header null not found")
headerxml = ElementTree.fromstring(header[:headerend])
headerxml = defusedxml.ElementTree.fromstring(header[:headerend])
nedfversion = headerxml.findtext("NEDFversion", "")
if nedfversion not in ["1.3", "1.4"]:
warn("NEDFversion unsupported, use with caution")
Expand Down
2 changes: 2 additions & 0 deletions mne/io/nedf/tests/test_nedf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
<STIMSettings/>
</nedf>\x00"""

pytest.importorskip("defusedxml")


@pytest.mark.parametrize("nacc", (0, 3))
def test_nedf_header_parser(nacc):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ dependencies = [
"packaging",
"jinja2",
"lazy_loader>=0.3",
"defusedxml",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -100,6 +99,7 @@ full = [
"edfio>=0.2.1",
"pybv",
"snirf",
"defusedxml",
]

# Dependencies for running the test infrastructure
Expand Down