Timing lsl (#337)
* Create Settings class with Pydantic validation

* Fixed a lot of bugs

* add pydantic dependency

* Remove StrEnum

* Partially fixed tests

* Refactor oscillatory

* Fix oscillatory bug

* vectorize_bispectra

* More validation and settings refactoring

* minor changes

* Settings class validation finished, all tests working

* Ignore wrong import order in __init__.py

* Fix tests, add retrocompatibility functions

* Add Pydantic dependency

* Autoformat

* Vectorize bursts calculation

* Fix bug

* autoformat

* typing change

* fix enable_all_features

* vectorize bursts calculation, fix settings reset bug

* review pydantic PR

* fix bursts merge, fix osc tests

* Remove getattr/setattr where possible

* fix bug in sharpwaves

* Add some comments, remove seglength

* Replace nm_settings.json with a .yaml

* Refactor normalizer classes

* add more comments about burst vectorization

* fix timing and logging in lsl stream

* Sharpwave refactor and bugfixing

* Speed-up sharpwaves computation

* fix sharpness bug

* fix nolds default settings fband name

* remove prints

* Make MNE-connectivity non-verbose

* Refactor nm_normalization + Handle divide by zero in feature normalization

* Remove a file

* Fix Numpy private function calls for Numpy 2.0 update

* update numpy and scipy dependencies for compatibility with numpy._core methods

* add function to read data via mne.io.read_raw

* fix doc example sharpwave

* fix add feature example

* changes in timing, exclude parallel processing

* Update test_feature_sampling_rates.py

Remove comments in tests

* Update test_timing.py

* fix test

---------

Co-authored-by: Antonio Martinez Brotons <[email protected]>
Co-authored-by: timonmerk <[email protected]>
3 people authored Jun 17, 2024
1 parent b7357bd commit 022402b
Showing 10 changed files with 38 additions and 199 deletions.
2 changes: 1 addition & 1 deletion py_neuromodulation/nm_mnelsl_generator.py
@@ -65,7 +65,7 @@ def __init__(
         self._path_raw = Path.cwd() / "temp_raw.fif"
         raw.save(self._path_raw, overwrite=True)

-    def start_player(self, chunk_size: int = 1, n_repeat: int = 1):
+    def start_player(self, chunk_size: int = 10, n_repeat: int = 1):
         """Start MNE-LSL Player
         Parameters
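Note: the default chunk size for the LSL player changes from 1 to 10 samples per push, reducing per-chunk overhead. A minimal sketch of the equivalent call against mne-lsl directly (the library this module wraps); the file name and stream name are placeholders, and the exact PlayerLSL signature may differ between mne-lsl versions:

from mne_lsl.player import PlayerLSL

# Replay a saved recording as an LSL stream, pushing 10 samples per chunk
# (mirrors the new start_player default above).
player = PlayerLSL("temp_raw.fif", chunk_size=10, name="example_stream")
player.start()
# ... consumers connect and pull data here ...
player.stop()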
8 changes: 7 additions & 1 deletion py_neuromodulation/nm_mnelsl_stream.py
@@ -68,14 +68,16 @@ def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
         self.last_time = time.time()
         check_data = None
         data = None
+        stream_start_time = None

         while self.stream.connected:
             time_diff = time.time() - self.last_time  # in s
             time.sleep(0.005)
             if time_diff >= self.sampling_interval:
                 self.last_time = time.time()

-                logger.info(f"Pull data - current time: {self.last_time}")
+                logger.debug(f"Pull data - current time: {self.last_time}")
+                logger.debug(f"time since last data pull {time_diff} seconds")

                 if time_diff >= 2 * self.sampling_interval:
                     logger.warning(
@@ -86,6 +88,8 @@ def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
                 check_data = data

                 data, timestamp = self.stream.get_data(winsize=self.winsize)
+                if stream_start_time is None:
+                    stream_start_time = timestamp[0]

                 for i in range(self._n_seconds_wait_before_disconnect):
                     if (
@@ -105,6 +109,8 @@ def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:

                 yield timestamp, data

+        logger.info(f"Stream time: {timestamp[-1] - stream_start_time}")

         if not self.headless and not self.listener.running:
             logger.info("Keyboard interrupt")
             self.stream.disconnect()
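Note: the changes above add a stream start-time reference and demote the per-pull log to debug level. A standalone sketch of the pull-timing pattern the loop implements (names simplified; `stream` stands in for an mne-lsl StreamLSL):

import time
import logging

logger = logging.getLogger(__name__)

def pull_batches(stream, sampling_interval: float, winsize: float):
    last_time = time.time()
    stream_start_time = None
    while stream.connected:
        time.sleep(0.005)  # light poll instead of busy-waiting
        time_diff = time.time() - last_time
        if time_diff < sampling_interval:
            continue
        last_time = time.time()
        if time_diff >= 2 * sampling_interval:
            # the pull arrived at least one full interval late
            logger.warning("Data pull was %.3f s late", time_diff - sampling_interval)
        data, timestamp = stream.get_data(winsize=winsize)
        if stream_start_time is None:
            stream_start_time = timestamp[0]  # reference for total stream time
        yield timestamp, data
    if stream_start_time is not None:
        logger.info("Stream time: %s", timestamp[-1] - stream_start_time)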
2 changes: 1 addition & 1 deletion py_neuromodulation/nm_run_analysis.py
@@ -282,7 +282,7 @@ def process(self, data: np.ndarray) -> dict[str, float]:
                 features_dict[ch] = np.nan

         if self.verbose:
-            logger.info("Last batch took: %.2f seconds", time() - start_time)
+            logger.info("Last batch took: %.3f seconds to process", time() - start_time)

         return features_dict
2 changes: 1 addition & 1 deletion py_neuromodulation/nm_sharpwaves.py
@@ -363,7 +363,7 @@ def analyze_waveform(self, data) -> dict:
            left_height = data[troughs_valid - int(5 * (1000 / self.sfreq))]
            right_height = data[troughs_valid + int(5 * (1000 / self.sfreq))]
            # results["sharpness"] = ((trough_height - left_height) + (trough_height - right_height)) / 2
-           results["sharpness"] = trough_height - 0.5 * (left_height + right_height)
+           results["sharpness"] = trough_height - 0.5 * (left_height + right_height)

            if self.need_steepness:
                # steepness is calculated as the first derivative
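Note: the removed and added results["sharpness"] lines are textually identical here, so the change is whitespace-only in this rendering. The commented-out formula above it is algebraically the same as the active one, which a quick check with arbitrary numbers confirms:

trough_height, left_height, right_height = 4.0, 1.0, 2.0

old = ((trough_height - left_height) + (trough_height - right_height)) / 2
new = trough_height - 0.5 * (left_height + right_height)
assert old == new == 2.5  # (2T - L - R) / 2 == T - (L + R) / 2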
6 changes: 0 additions & 6 deletions py_neuromodulation/nm_stream_abc.py
@@ -122,12 +122,6 @@ def run(self) -> pd.DataFrame:
             verbose=self.verbose,
         )

-    @abstractmethod
-    def _add_timestamp(self, feature_series: pd.Series, cnt_samples: int) -> pd.Series:
-        """Add to feature_series "time" keyword
-        For Bids specify with fs_features, for real time analysis with current time stamp
-        """
-
     @staticmethod
     def _get_sess_lat(coords: dict) -> bool:
         if len(coords["cortex_left"]["positions"]) == 0:
116 changes: 25 additions & 91 deletions py_neuromodulation/nm_stream_offline.py
@@ -49,18 +49,6 @@ def _add_target(self, feature_dict: dict, data: np.ndarray) -> None:
         for target_idx, target_name in zip(self.target_indexes, self.target_names):
             feature_dict[target_name] = data[target_idx, -1]

-    def _add_timestamp(self, feature_dict: dict, cnt_samples: int) -> None:
-        """Add time stamp in ms.
-        Due to normalization DataProcessor needs to keep track of the counted
-        samples. These are accessed here for time conversion.
-        """
-        timestamp = cnt_samples * 1000 / self.sfreq
-        feature_dict["time"] = timestamp
-
-        if self.verbose:
-            logger.info("%.2f seconds of data processed", timestamp / 1000)
-
     def _handle_data(self, data: np.ndarray | pd.DataFrame) -> np.ndarray:
         names_expected = self.nm_channels["name"].to_list()

@@ -87,35 +75,6 @@ def _handle_data(self, data: np.ndarray | pd.DataFrame) -> np.ndarray:
         )
         return data.to_numpy().transpose()

-    def _check_settings_for_parallel(self):
-        """Check specified settings and raise error if parallel processing is not possible.
-        Raises:
-            ValueError: depending on the settings, parallel processing is not possible
-        """
-
-        if "raw_normalization" in self.settings.preprocessing:
-            raise ValueError(
-                "Parallel processing is not possible with raw_normalization normalization."
-            )
-        if self.settings.postprocessing.feature_normalization:
-            raise ValueError(
-                "Parallel processing is not possible with feature normalization."
-            )
-        if self.settings.features.bursts:
-            raise ValueError(
-                "Parallel processing is not possible with burst estimation."
-            )
-
-    def _process_batch(self, data_batch, cnt_samples):
-        # if isinstance(data_batch, tuple):
-        #     data_batch = np.array(data_batch[1])
-
-        feature_dict = self.data_processor.process(data_batch[1].astype(np.float64))
-        self._add_timestamp(feature_dict, cnt_samples)
-        self._add_target(feature_dict, data_batch[1])
-        return feature_dict
-
     def _run(
         self,
         data: np.ndarray | pd.DataFrame | None = None,
@@ -124,8 +83,6 @@ def _run(
         is_stream_lsl: bool = True,
         stream_lsl_name: str = None,
         plot_lsl: bool = False,
-        parallel: bool = False,
-        n_jobs: int = -2,
     ) -> pd.DataFrame:
         from py_neuromodulation.nm_generator import raw_data_generator

@@ -159,54 +116,38 @@ def _run(

             generator = self.lsl_stream.get_next_batch()

-        sample_add = self.sfreq / self.data_processor.sfreq_features
-
-        offset_time = self.settings.segment_length_features_ms
-        # offset_start = np.ceil(offset_time / 1000 * self.sfreq).astype(int)
-        offset_start = offset_time / 1000 * self.sfreq
-
-        if parallel:
-            from joblib import Parallel, delayed
-            from itertools import count
-
-            # parallel processing can not be utilized if a LSL stream is used
-            if is_stream_lsl:
-                error_msg = "Parallel processing is not possible with LSL stream."
-                logger.error(error_msg)
-                raise ValueError(error_msg)
-
-            l_features = Parallel(n_jobs=n_jobs, verbose=10)(
-                delayed(self._process_batch)(data_batch, cnt_samples)
-                for data_batch, cnt_samples in zip(
-                    generator, count(offset_start, sample_add)
-                )
-            )
-        else:
-            l_features: list[dict] = []
-            cnt_samples = offset_start
-
-            while True:
-                next_item = next(generator, None)
-
-                if next_item is not None:
-                    time_, data_batch = next_item
-                else:
-                    break
-
-                if data_batch is None:
-                    break
-                feature_dict = self.data_processor.process(
-                    data_batch.astype(np.float64)
-                )
-                self._add_timestamp(feature_dict, cnt_samples)
-                self._add_target(feature_dict, data_batch)
-
-                l_features.append(feature_dict)
-
-                cnt_samples += sample_add
+        l_features: list[dict] = []
+        last_time = None
+
+        while True:
+            next_item = next(generator, None)
+
+            if next_item is not None:
+                time_, data_batch = next_item
+            else:
+                break
+
+            if data_batch is None:
+                break
+            feature_dict = self.data_processor.process(
+                data_batch.astype(np.float64)
+            )
+            if is_stream_lsl:
+                feature_dict["time"] = time_[-1]
+                if self.verbose:
+                    if last_time is not None:
+                        logger.debug("%.3f seconds of new data processed", time_[-1] - last_time)
+                    last_time = time_[-1]
+            else:
+                feature_dict["time"] = np.ceil(time_[-1] * 1000 + 1).astype(int)
+                logger.info("Time: %.2f", feature_dict["time"] / 1000)
+
+            self._add_target(feature_dict, data_batch)
+
+            l_features.append(feature_dict)

-        feature_df = pd.DataFrame.from_records(l_features).astype(np.float64)
+        feature_df = pd.DataFrame(l_features)

         self.save_after_stream(out_path_root, folder_name, feature_df)

@@ -344,8 +285,6 @@ def run(
         data: np.ndarray | pd.DataFrame | None = None,
         out_path_root: _PathLike = Path.cwd(),
         folder_name: str = "sub",
-        parallel: bool = False,
-        n_jobs: int = -2,
         stream_lsl: bool = False,
         stream_lsl_name: str = None,
         plot_lsl: bool = False,
@@ -380,9 +319,6 @@ def run(
         elif self.data is None and data is None and self.stream_lsl is False:
             raise ValueError("No data passed to run function.")

-        if parallel:
-            self._check_settings_for_parallel()
-
         out_path = Path(out_path_root, folder_name)
         out_path.mkdir(parents=True, exist_ok=True)
         logger.log_to_file(out_path)
@@ -391,8 +327,6 @@ def run(
             data,
             out_path_root,
             folder_name,
-            parallel=parallel,
-            n_jobs=n_jobs,
             is_stream_lsl=stream_lsl,
             stream_lsl_name=stream_lsl_name,
             plot_lsl=plot_lsl,
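Note: with _add_timestamp gone, the "time" column now comes straight from the generator: the LSL timestamp of the newest sample for live streams, and a millisecond conversion for offline data. A self-contained sketch of that branch (the example values are hypothetical):

import numpy as np

def batch_time(time_: np.ndarray, is_stream_lsl: bool):
    if is_stream_lsl:
        return time_[-1]  # LSL clock of the newest sample, in seconds
    # offline: last sample time in ms; +1 compensates for zero-based counting
    return int(np.ceil(time_[-1] * 1000 + 1))

# An offline batch whose last sample falls at t = 0.5 s gets time = 501 ms.
assert batch_time(np.array([0.499, 0.5]), is_stream_lsl=False) == 501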
2 changes: 2 additions & 0 deletions tests/test_feature_sampling_rates.py
@@ -27,6 +27,8 @@ def test_different_sampling_rate_100Hz():
     df = stream.run(arr_test)

     # check the difference between time points
+    #print(df["time"].iloc[:2])
+    #print(sampling_rate_features)

     assert np.diff(df["time"].iloc[:2]) / 1000 == (1 / sampling_rate_features)
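Note: the assertion checks that consecutive feature timestamps are spaced at the feature sampling interval. A worked instance with hypothetical numbers:

import numpy as np

sampling_rate_features = 100  # Hz, as in the test name
times_ms = np.array([1001.0, 1011.0])  # two hypothetical "time" values in ms
assert np.diff(times_ms)[0] / 1000 == 1 / sampling_rate_features  # 10 ms apart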
34 changes: 0 additions & 34 deletions tests/test_lsl_stream.py
@@ -5,40 +5,6 @@
 import time
 import threading


-def test_mne_lsl(setup_default_data):
-    """ Test the mne-lsl package and the core functionality of the player and stream classes. """
-    raw, data, sfreq = setup_default_data
-
-    player1 = PlayerLSL(raw, name="general_lsl_test_stream", chunk_size=10)
-    player1 = player1.start()
-
-    stream1 = StreamLSL(name="general_lsl_test_stream", bufsize=2).connect()
-    ch_types = stream1.get_channel_types(unique=True)
-    assert 'dbs' in ch_types, "Expected at least one dbs channel from example data"
-    assert player1.info['nchan'] == 10, "Expected 10 channels from example data"
-    data_l = []
-    timestamps_l = []
-
-    def call_every_100ms():
-        data, timestamps = stream1.get_data(winsize=10)
-        data_l.append(data)
-        timestamps_l.append(timestamps)
-
-    t = threading.Timer(0.1, call_every_100ms)
-    t.start()
-
-    time_start = time.time()
-
-    while time.time() - time_start <= 10:
-        time.sleep(1)
-    t.cancel()
-
-    collected_data_shape = np.concatenate(data_l).shape
-    assert collected_data_shape[0] > 0 and collected_data_shape[1] > 0, "Expected non-empty data"
-
-    stream1.disconnect()
-    player1.stop()

 @pytest.mark.parametrize('setup_lsl_player', ['search'], indirect=True)
 def test_lsl_stream_search(setup_lsl_player):
63 changes: 0 additions & 63 deletions tests/test_multiprocessing.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_timing.py
@@ -75,4 +75,4 @@ def test_float_fs():
         features.time.iloc[1] - features.time.iloc[0]
     ) == 1000 / sampling_rate_features_hz

-    assert features.time.iloc[0] == settings.segment_length_features_ms
+    assert features["time"].iloc[0] - 1 == settings["segment_length_features_ms"]  # remove 1 due to python counting
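Note: the updated assertion accounts for the +1 ms introduced by the offline timestamping in nm_stream_offline.py (see the sketch above). With hypothetical settings:

import numpy as np

segment_length_features_ms = 1000  # hypothetical settings value
# the first offline batch ends at t = 1.0 s, so its timestamp is ceil(1000 + 1) ms
first_time = int(np.ceil(1.0 * 1000 + 1))
assert first_time - 1 == segment_length_features_ms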
