Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing topomaps. #130

Merged
merged 10 commits into from
Oct 26, 2023
52 changes: 35 additions & 17 deletions pylossless/dash/topo_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
yaxis.update({"scaleanchor": "x", "scaleratio": 1})


def pick_montage(montage, ch_names):
"""Pick a subset of channels from a montage."""
digs = montage.remove_fiducials().dig
digs = [dig for dig, ch_name in zip(digs, montage.ch_names) if ch_name in ch_names]
return mne.channels.DigMontage(dig=digs, ch_names=ch_names)

Check warning on line 40 in pylossless/dash/topo_viz.py

View check run for this annotation

Codecov / codecov/patch

pylossless/dash/topo_viz.py#L38-L40

Added lines #L38 - L40 were not covered by tests


class TopoPlot: # TODO: Fix/finish doc comments for this class.
"""Representation of a classic EEG topographic map as a plotly figure."""

Expand All @@ -47,7 +54,7 @@
res=64,
width=None,
height=None,
cmap="RdBu_r",
cmap=None,
show_sensors=True,
colorbar=False,
):
Expand Down Expand Up @@ -162,9 +169,11 @@
self.info = create_info(names, sfreq=256, ch_types="eeg")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# To update self.info with channels positions
RawArray(
np.zeros((len(names), 1)), self.info, copy=None, verbose=False
).set_montage(self.montage)
assert np.all(np.array(names) == np.array(self.info.ch_names))
self.set_head_pos_contours()

# TODO: Finish/fix docstring
Expand Down Expand Up @@ -278,12 +287,24 @@
-------
A plotly.graph_objects.Figure object.
"""
from .utils import _setup_vmin_vmax

if self.__data is None:
return

data = np.array(list(self.__data.values()))
norm = min(np.array(data)) >= 0
vmin, vmax = _setup_vmin_vmax(data, None, None, norm)
if self.cmap is None:
cmap = "Reds" if norm else "RdBu_r"
else:
cmap = self.cmap

Check warning on line 301 in pylossless/dash/topo_viz.py

View check run for this annotation

Codecov / codecov/patch

pylossless/dash/topo_viz.py#L301

Added line #L301 was not covered by tests

heatmap_trace = go.Heatmap(
showscale=self.colorbar,
colorscale=self.cmap,
colorscale=cmap,
zmin=vmin,
zmax=vmax,
**self.get_heatmap_data(**kwargs)
)

Expand Down Expand Up @@ -341,9 +362,10 @@
See mne.channels.make_standard_montage(), and
mne.channels.get_builtin_montages() for more information
on making montage objects in MNE.
data : mne.preprocessing.ICA | None
The data to use for the topoplots. Can be an instance of
mne.preprocessing.ICA.
data : list | None
The data to use for the topoplots. Should be a list of
dictionaries, one per topomap. The dictionaries should
have the channel names as keys.
figure : plotly.graph_objects.Figure | None
Figure to use (if not None) for plotting.
color : str
Expand Down Expand Up @@ -635,13 +657,11 @@

# The indexing with ch_names is to ensure the order
# of the channels are compatible between plot_data and the montage
ch_names = [
ch_name
for ch_name in self.montage.ch_names
if ch_name in self.data.topo_values.columns
montage = pick_montage(self.montage, self.data.topo_values.columns)
ch_names = montage.ch_names
plot_data = [

Check warning on line 662 in pylossless/dash/topo_viz.py

View check run for this annotation

Codecov / codecov/patch

pylossless/dash/topo_viz.py#L660-L662

Added lines #L660 - L662 were not covered by tests
OrderedDict(self.data.topo_values.loc[title, ch_names]) for title in titles
]
plot_data = self.data.topo_values.loc[titles, ch_names]
plot_data = list(plot_data.T.to_dict().values())

if len(plot_data) < self.nb_sel_topo:
nb_missing_topo = self.nb_sel_topo - len(plot_data)
Expand All @@ -650,7 +670,7 @@
self.figure = GridTopoPlot(
rows=self.rows,
cols=self.cols,
montage=self.montage,
montage=montage,
data=plot_data,
color=colors,
res=self.res,
Expand Down Expand Up @@ -807,16 +827,14 @@
return None

data = TopoData(
[
dict(zip(montage.ch_names, component))
for component in ica.get_components().T
]
[dict(zip(ica.ch_names, component)) for component in ica.get_components().T]
)
data.topo_values.index = ica._ica_names

Check warning on line 832 in pylossless/dash/topo_viz.py

View check run for this annotation

Codecov / codecov/patch

pylossless/dash/topo_viz.py#L832

Added line #L832 was not covered by tests

if ic_labels:
self.head_contours_color = {
comp: ic_label_cmap[label] for comp, label in ic_labels.items()
}
data.topo_values.index = list(ic_labels.keys())
return data

def load_recording(self, montage, ica, ic_labels):
Expand Down
24 changes: 24 additions & 0 deletions pylossless/dash/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np


def _setup_vmin_vmax(data, vmin, vmax, norm=False):
"""Handle vmin and vmax parameters for visualizing topomaps.

This is a simplified copy of mne.viz.utils._setup_vmin_vmax.
https://github.com/mne-tools/mne-python/blob/main/mne/viz/utils.py

Notes
-----
For the normal use-case (when `vmin` and `vmax` are None), the parameter
`norm` drives the computation. When norm=False, data is supposed to come
from a mag and the output tuple (vmin, vmax) is symmetric range
(-x, x) where x is the max(abs(data)). When norm=True (a.k.a. data is the
L2 norm of a gradiometer pair) the output tuple corresponds to (0, x).

in the MNE version vmin and vmax can be callables that drive the operation,
but for the sake of simplicity this was not copied over.
"""
if vmax is None and vmin is None:
vmax = np.abs(data).max()
vmin = 0.0 if norm else -vmax
return vmin, vmax
Loading