diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml index c72e2fdd..80d79b6c 100644 --- a/py_neuromodulation/default_settings.yaml +++ b/py_neuromodulation/default_settings.yaml @@ -193,6 +193,7 @@ coherence_settings: method: coh: true icoh: true + nperseg: 128 fooof_settings: aperiodic: diff --git a/py_neuromodulation/features/coherence.py b/py_neuromodulation/features/coherence.py index f3af5d8d..67b01b42 100644 --- a/py_neuromodulation/features/coherence.py +++ b/py_neuromodulation/features/coherence.py @@ -26,7 +26,7 @@ class CoherenceFeatures(BoolSelector): mean_fband: bool = True max_fband: bool = True max_allfbands: bool = True - + ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)] @@ -35,6 +35,7 @@ class CoherenceSettings(NMBaseModel): features: CoherenceFeatures = CoherenceFeatures() method: CoherenceMethods = CoherenceMethods() channels: list[ListOfTwoStr] = [] + nperseg: int = Field(default=128, ge=0) frequency_bands: list[str] = Field(default=["high_beta"], min_length=1) @field_validator("frequency_bands") @@ -49,6 +50,7 @@ def __init__( window: str, fbands: list[FrequencyRange], fband_names: list[str], + nperseg: int, ch_1_name: str, ch_2_name: str, ch_1_idx: int, @@ -65,6 +67,7 @@ def __init__( self.ch_2 = ch_2_name self.ch_1_idx = ch_1_idx self.ch_2_idx = ch_2_idx + self.nperseg = nperseg self.coh = coh self.icoh = icoh self.features_coh = features_coh @@ -79,9 +82,9 @@ def __init__( def get_coh(self, feature_results, x, y): from scipy.signal import welch, csd - self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=128) - self.Pyy = welch(y, self.sfreq, self.window, nperseg=128)[1] - self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=128)[1] + self.f, self.Pxx = welch(x, self.sfreq, self.window, nperseg=self.nperseg) + self.Pyy = welch(y, self.sfreq, self.window, nperseg=self.nperseg)[1] + self.Pxy = csd(x, y, self.sfreq, self.window, nperseg=self.nperseg)[1] if self.coh and self.icoh: cohy = self.Pxy / np.sqrt(self.Pxx * self.Pyy) @@ -184,6 +187,7 @@ def __init__( "hann", fband_specs, fband_names, + self.settings.nperseg, ch_1_name, ch_2_name, ch_1_idx, diff --git a/tests/test_coherence.py b/tests/test_coherence.py index 47e74ccb..50fde363 100644 --- a/tests/test_coherence.py +++ b/tests/test_coherence.py @@ -70,7 +70,8 @@ def test_coherence(): # connectivity indices, i.e.: ([0, 1], [1, 0]) # FIXME: indices with only 1 channel in seeds/targets raises error with type validatn - settings.coherence_settings.channels = [ch_names, ch_names[::-1]] + #settings.coherence_settings.channels = [ch_names, ch_names[::-1]] + settings.coherence_settings.channels = [ch_names] # , ch_names[::-1] # do not normalise features for this test! # (normalisation changes interpretability of connectivity values, making it harder to