Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing topomaps. #130

Merged
merged 10 commits into from
Oct 26, 2023
2 changes: 1 addition & 1 deletion pylossless/dash/tests/test_topo_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_GridTopoPlot():
offset = 2
nb_topo = 4
plot_data = topo_data.topo_values.iloc[::-1].iloc[offset : offset + nb_topo]
plot_data = list(plot_data.T.to_dict().values())
plot_data = plot_data.values.tolist()
christian-oreilly marked this conversation as resolved.
Show resolved Hide resolved

GridTopoPlot(
2,
Expand Down
55 changes: 38 additions & 17 deletions pylossless/dash/topo_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
yaxis.update({"scaleanchor": "x", "scaleratio": 1})


def pick_montage(montage, ch_names):
"""Pick a subset of channels from a montage."""
digs = montage.remove_fiducials().dig
assert len(digs) == len(montage.ch_names)
christian-oreilly marked this conversation as resolved.
Show resolved Hide resolved
digs = [dig for dig, ch_name in zip(digs, montage.ch_names) if ch_name in ch_names]
return mne.channels.DigMontage(dig=digs, ch_names=ch_names)


class TopoPlot: # TODO: Fix/finish doc comments for this class.
"""Representation of a classic EEG topographic map as a plotly figure."""

Expand All @@ -47,7 +55,7 @@ def __init__(
res=64,
width=None,
height=None,
cmap="RdBu_r",
cmap=None,
show_sensors=True,
colorbar=False,
):
Expand Down Expand Up @@ -162,9 +170,11 @@ def set_data(self, data):
self.info = create_info(names, sfreq=256, ch_types="eeg")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# To update self.info with channels positions
RawArray(
np.zeros((len(names), 1)), self.info, copy=None, verbose=False
).set_montage(self.montage)
assert np.all(np.array(names) == np.array(self.info.ch_names))
self.set_head_pos_contours()

# TODO: Finish/fix docstring
Expand Down Expand Up @@ -278,12 +288,24 @@ def plot_topo(self, **kwargs):
-------
A plotly.graph_objects.Figure object.
"""
from .utils import _setup_vmin_vmax

if self.__data is None:
return

data = np.array(list(self.__data.values()))
norm = min(np.array(data)) >= 0
vmin, vmax = _setup_vmin_vmax(data, None, None, norm)
if self.cmap is None:
cmap = "Reds" if norm else "RdBu_r"
else:
cmap = self.cmap

heatmap_trace = go.Heatmap(
showscale=self.colorbar,
colorscale=self.cmap,
colorscale=cmap,
zmin=vmin,
zmax=vmax,
**self.get_heatmap_data(**kwargs)
)

Expand Down Expand Up @@ -341,9 +363,10 @@ def __init__(
See mne.channels.make_standard_montage(), and
mne.channels.get_builtin_montages() for more information
on making montage objects in MNE.
data : mne.preprocessing.ICA | None
The data to use for the topoplots. Can be an instance of
mne.preprocessing.ICA.
data : list | None
The data to use for the topoplots. Should be a list of
dictionaries, one per topomap. The dictionaries should
have the channel names as keys.
figure : plotly.graph_objects.Figure | None
Figure to use (if not None) for plotting.
color : str
Expand Down Expand Up @@ -635,13 +658,13 @@ def initialize_layout(self, slider_val=None, show_sensors=True):

# The indexing with ch_names is to ensure the order
# of the channels are compatible between plot_data and the montage
ch_names = [
ch_name
for ch_name in self.montage.ch_names
if ch_name in self.data.topo_values.columns
montage = pick_montage(self.montage, self.data.topo_values.columns)
ch_names = montage.ch_names
assert len(ch_names) == len(self.data.topo_values.columns)
christian-oreilly marked this conversation as resolved.
Show resolved Hide resolved
assert np.sum(np.in1d(ch_names, self.data.topo_values.columns)) == len(ch_names)
christian-oreilly marked this conversation as resolved.
Show resolved Hide resolved
plot_data = [
OrderedDict(self.data.topo_values.loc[title, ch_names]) for title in titles
]
plot_data = self.data.topo_values.loc[titles, ch_names]
plot_data = list(plot_data.T.to_dict().values())

if len(plot_data) < self.nb_sel_topo:
nb_missing_topo = self.nb_sel_topo - len(plot_data)
Expand All @@ -650,7 +673,7 @@ def initialize_layout(self, slider_val=None, show_sensors=True):
self.figure = GridTopoPlot(
rows=self.rows,
cols=self.cols,
montage=self.montage,
montage=montage,
data=plot_data,
color=colors,
res=self.res,
Expand Down Expand Up @@ -807,16 +830,14 @@ def init_vars(self, montage, ica, ic_labels):
return None

data = TopoData(
[
dict(zip(montage.ch_names, component))
for component in ica.get_components().T
]
[dict(zip(ica.ch_names, component)) for component in ica.get_components().T]
)
data.topo_values.index = ica._ica_names

if ic_labels:
self.head_contours_color = {
comp: ic_label_cmap[label] for comp, label in ic_labels.items()
}
data.topo_values.index = list(ic_labels.keys())
return data

def load_recording(self, montage, ica, ic_labels):
Expand Down
24 changes: 24 additions & 0 deletions pylossless/dash/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np


def _setup_vmin_vmax(data, vmin, vmax, norm=False):
"""Handle vmin and vmax parameters for visualizing topomaps.

This is a simplified copy of mne.viz.utils._setup_vmin_vmax.
https://github.com/mne-tools/mne-python/blob/main/mne/viz/utils.py

Notes
-----
For the normal use-case (when `vmin` and `vmax` are None), the parameter
`norm` drives the computation. When norm=False, data is supposed to come
from a mag and the output tuple (vmin, vmax) is symmetric range
(-x, x) where x is the max(abs(data)). When norm=True (a.k.a. data is the
L2 norm of a gradiometer pair) the output tuple corresponds to (0, x).

in the MNE version vmin and vmax can be callables that drive the operation,
but for the sake of simplicity this was not copied over.
"""
if vmax is None and vmin is None:
vmax = np.abs(data).max()
vmin = 0.0 if norm else -vmax
return vmin, vmax