Skip to content

Commit

Permalink
Refactor Stream class
Browse files Browse the repository at this point in the history
  • Loading branch information
toni-neurosc committed Sep 23, 2024
1 parent 45dee6c commit 2ab0bf8
Show file tree
Hide file tree
Showing 31 changed files with 506 additions and 429 deletions.
5 changes: 3 additions & 2 deletions examples/plot_0_first_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def generate_random_walk(NUM_CHANNELS, TIME_DATA_SAMPLES):
# DataFrame. There are some helper functions that let you create the
# nm_channels without much effort:

nm_channels = nm.utils.get_default_channels_from_data(data, car_rereferencing=True)
nm_channels = nm.utils.create_default_channels_from_data(data, car_rereferencing=True)

nm_channels

Expand Down Expand Up @@ -135,14 +135,15 @@ def generate_random_walk(NUM_CHANNELS, TIME_DATA_SAMPLES):
# We are now ready to go to instantiate the *Stream* and call the *run* method for feature estimation:

stream = nm.Stream(
data=data,
settings=settings,
channels=nm_channels,
verbose=True,
sfreq=sfreq,
line_noise=50,
)

features = stream.run(data, save_csv=True)
features = stream.run(save_csv=True)

# %%
# Feature Analysis
Expand Down
6 changes: 3 additions & 3 deletions examples/plot_1_example_BIDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
coord_names,
) = nm.io.read_BIDS_data(PATH_RUN=PATH_RUN)

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand Down Expand Up @@ -94,6 +94,8 @@

# %%
stream = nm.Stream(
data=data,
experiment_name=RUN_NAME,
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -105,9 +107,7 @@

# %%
features = stream.run(
data=data,
out_dir=PATH_OUT,
experiment_name=RUN_NAME,
save_csv=True,
)

Expand Down
19 changes: 11 additions & 8 deletions examples/plot_3_example_sharpwave_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
coord_names,
) = nm.io.read_BIDS_data(PATH_RUN=PATH_RUN)

print(data.shape)

# %%
settings = NMSettings.get_fast_compute()

Expand All @@ -69,7 +71,7 @@
for sw_feature in settings.sharpwave_analysis_settings.sharpwave_features.list_all():
settings.sharpwave_analysis_settings.estimator["mean"].append(sw_feature)

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand All @@ -79,7 +81,9 @@
target_keywords=["MOV_RIGHT"],
)

stream = nm.Stream(
data_plt = data[5, 1000:4000]

data_processor = nm.DataProcessor(
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -88,14 +92,12 @@
coord_names=coord_names,
verbose=False,
)
sw_analyzer = cast(
SharpwaveAnalyzer, stream.data_processor.features.get_feature("sharpwave_analysis")
)

sw_analyzer = data_processor.features.get_feature("sharpwave_analysis")


# %%
# The plotted example time series, visualized on a short time scale, shows the relation of identified peaks, troughs, and estimated features:
data_plt = data[5, 1000:4000]

filtered_dat = fftconvolve(data_plt, sw_analyzer.list_filter[0][1], mode="same")

troughs = signal.find_peaks(-filtered_dat, distance=10)[0]
Expand Down Expand Up @@ -297,6 +299,7 @@
channels.loc[[3, 8], "used"] = 1

stream = nm.Stream(
data=data[:, :30000],
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -306,7 +309,7 @@
verbose=True,
)

df_features = stream.run(data=data[:, :30000], save_csv=True)
df_features = stream.run(save_csv=True)

# %%
# We can then plot two exemplary features, prominence and interval, and see that the movement amplitude can be clustered with those two features alone:
Expand Down
6 changes: 3 additions & 3 deletions examples/plot_4_example_gridPointProjection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

settings.postprocessing.project_cortex = True

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand All @@ -65,6 +65,8 @@
)

stream = nm.Stream(
data=data[:, : int(sfreq * 5)],
experiment_name=RUN_NAME,
sfreq=sfreq,
channels=channels,
settings=settings,
Expand All @@ -75,9 +77,7 @@
)

features = stream.run(
data=data[:, : int(sfreq * 5)],
out_dir=PATH_OUT,
experiment_name=RUN_NAME,
save_csv=True,
)

Expand Down
1 change: 1 addition & 0 deletions examples/plot_6_real_time_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def get_fast_compute_settings():
print("Computation time for single ECoG channel: ")
data = np.random.random([1, 1000])
stream = nm.Stream(sfreq=1000, data=data, sampling_rate_features_hz=10, verbose=False)

print(
f"{np.round(timeit.timeit(lambda: stream.data_processor.process(data), number=10)/10, 3)} s"
)
Expand Down
17 changes: 9 additions & 8 deletions examples/plot_7_lsl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# %%
from matplotlib import pyplot as plt
import py_neuromodulation as nm
import time

# %%
# Let’s get the example data from the provided BIDS dataset and create the channels DataFrame.
Expand All @@ -32,7 +33,7 @@
coord_names,
) = nm.io.read_BIDS_data(PATH_RUN=PATH_RUN)

channels = nm.utils.set_channels(
channels = nm.utils.create_channels(
ch_names=raw.ch_names,
ch_types=raw.get_channel_types(),
reference="default",
Expand Down Expand Up @@ -61,6 +62,9 @@
player = nm.stream.LSLOfflinePlayer(raw=raw, stream_name="example_stream")

player.start_player(chunk_size=30)

time.sleep(2) # Wait for stream to start

# %%
# Creating the LSLStream object
# -----------------------------
Expand All @@ -78,6 +82,9 @@
# %%
stream = nm.Stream(
sfreq=sfreq,
experiment_name=RUN_NAME,
is_stream_lsl=True,
stream_lsl_name="example_stream",
channels=channels,
settings=settings,
coord_list=coord_list,
Expand All @@ -87,13 +94,7 @@
# %%
# We then simply have to set the `stream_lsl` parameter to be `True` and specify the `stream_lsl_name`.

features = stream.run(
is_stream_lsl=True,
plot_lsl=False,
stream_lsl_name="example_stream",
out_dir=PATH_OUT,
experiment_name=RUN_NAME,
)
features = stream.run(out_dir=PATH_OUT)

# %%
# We can then look at the computed features and check if the streamed data was processed correctly.
Expand Down
1 change: 1 addition & 0 deletions py_neuromodulation/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@
FeatureProcessors,
add_custom_feature,
remove_custom_feature,
USE_FREQ_RANGES,
)
10 changes: 10 additions & 0 deletions py_neuromodulation/features/feature_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
import numpy as np
from py_neuromodulation import NMSettings

USE_FREQ_RANGES: list[FeatureName] = [
"bandpass_filter",
"stft",
"fft",
"welch",
"bursts",
"coherence",
"nolds",
"bispectrum",
]

FEATURE_DICT: dict[FeatureName | str, str] = {
"raw_hjorth": "Hjorth",
Expand Down
29 changes: 4 additions & 25 deletions py_neuromodulation/gui/backend/app_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Query,
WebSocket,
)
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware

Expand Down Expand Up @@ -71,7 +70,9 @@ def __init__(

def push_features_to_frontend(self, feature_queue: Queue) -> None:
while True:
time.sleep(0.002) # NOTE: should be adapted depending on feature sampling rate
time.sleep(
0.002
) # NOTE: should be adapted depending on feature sampling rate
if feature_queue.empty() is False:
self.logger.info("data in feature queue")
features = feature_queue.get()
Expand Down Expand Up @@ -231,7 +232,6 @@ async def setup_offline_stream(data: dict):
#######################

@self.get("/api/app-info")
# TODO: fix this function
async def get_app_info():
metadata = importlib.metadata.metadata("py_neuromodulation")
url_list = metadata.get_all("Project-URL")
Expand Down Expand Up @@ -353,25 +353,4 @@ def quick_access():
###########################
@self.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
# if self.websocket_manager.is_connected:
# self.logger.info(
# "WebSocket connection attempted while already connected"
# )
# await websocket.close(
# code=1008, reason="Another client is already connected"
# )
# return

await self.websocket_manager.connect(websocket)
# # #######################
# # ### SPA ENTRY POINT ###
# # #######################
# if not self.dev:

# @self.get("/app/{full_path:path}")
# async def serve_spa(request, full_path: str):
# # Serve the index.html for any path that doesn't match an API route
# print(Path.cwd())
# return FileResponse("frontend/index.html")


await self.websocket_manager.connect(websocket)
Loading

0 comments on commit 2ab0bf8

Please sign in to comment.