Commit 5cdec8d

adapted example GSD

Aksei committed Oct 4, 2024
1 parent 150b7ff commit 5cdec8d
Showing 5 changed files with 29 additions and 196 deletions.
8 changes: 8 additions & 0 deletions eargait/utils/example_data.py
@@ -42,3 +42,11 @@ def plot_image(image_path):
    plt.axis("off")
    plt.imshow(im_array)
    plt.show()


def load_groundtruth(csv_path, target_sample_rate):
    # The start/stop indices in the CSV are assumed to refer to the original 200 Hz
    # recording; rescale them to the target sampling rate.
    csv_table = pd.read_csv(csv_path)
    downsampling_factor = 200 / target_sample_rate
    csv_table["start"] = (csv_table["start"] / downsampling_factor).astype(int)
    csv_table["stop"] = (csv_table["stop"] / downsampling_factor).astype(int)
    return csv_table
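
# A minimal usage sketch (illustrative path; the hard-coded 200 assumes the
# annotations were created at the original 200 Hz recording rate):
#
# gt = load_groundtruth("walking_bout_indices.csv", target_sample_rate=50)
# gt[["start", "stop"]].head()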
27 changes: 0 additions & 27 deletions eargait/utils/overlapping_regions.py
@@ -303,33 +303,6 @@ def _get_false_matches_from_overlap_data(overlaps: list[Interval], interval: Interval) -> list[list[int]]:
    return f_intervals


"""def _get_false_matches_from_overlap_data(overlaps: list[Interval], interval: Interval) -> list[list[int]]: # noqa
f_intervals = []
for i, overlap in enumerate(overlaps):
prev_el = overlaps[i - 1] if i > 0 else None
next_el = overlaps[i + 1] if i < len(overlaps) - 1 else None
# check if there are false matches before the overlap
if interval.begin < overlap.begin:
fn_start = interval.begin
# check if interval is already covered by a previous overlap
if prev_el and interval.begin < prev_el.end:
fn_start = prev_el.end
f_intervals.append([fn_start, overlap.begin])
# check if there are false matches after the overlap
if interval.end > overlap.end:
fn_end = interval.end
# check if interval is already covered by a succeeding overlap
if next_el and interval.end > next_el.begin:
# skip because this will be handled by the next iteration
continue
# fn_end = next_el.begin
f_intervals.append([overlap.end, fn_end])
return f_intervals"""


def plot_categorized_intervals(
    gsd_list_detected: pd.DataFrame, gsd_list_reference: pd.DataFrame, categorized_intervals: pd.DataFrame
) -> Figure:
186 changes: 19 additions & 167 deletions examples/gait_sequence_detector/gait_sequenece_detection_pipeline.py
@@ -11,42 +11,28 @@
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from signialib import Session

from eargait.gait_sequence_detection.gait_sequence_detector import GaitSequenceDetection
from eargait.utils.example_data import get_mat_example_data_path, load_groundtruth

# %%
# Getting example data
# -------------------------
# --------------------
#
# First, we import the necessary modules and load example data.
import pandas as pd

from eargait.gait_sequence_detection.gait_sequence_detector import GaitSequenceDetection
from eargait.utils.example_data import get_mat_example_data_path

# Path to data file (.mat) or .mat data directory
# Should a Session also work with a txt file? And thus gait detection as well?
repo_root = Path(__file__).resolve().parent.parent.parent
print("Repo Root", repo_root)

# Get example data and load it
data_path = get_mat_example_data_path()
csv_path = repo_root / "example_data/mat_files/walking_bout_indices.csv"

# %%
# Loading data
# ----------------
#
# A data session refers to a recording (by Signia Hearing Aids).
# A session can consist of a single `*.txt` or `*.mat` file, or of two `*.mat` files for the left and right ear, respectively.
# The session is loaded from the local path `data_path` of the directory in which the Matlab/txt file(s) are stored.
from signialib import Session
target_sample_rate = 50

session = Session.from_folder_path(data_path)
# It is recommended NOT to set skip_calibration=True, but to use an up-to-date calibration file for the sensor in use.
align_calibrate_sess = session.align_calib_resample(resample_rate_hz=50, skip_calibration=True)
session.info
align_calibrate_sess = session.align_calib_resample(resample_rate_hz=target_sample_rate, skip_calibration=True)

# %%
# Gravity alignment and transformation into body frame
# ---------------------------------------------------------
# ----------------------------------------------------
#
# Align session to gravity and transform the coordinate system into a body frame.
# Two methods are provided: `StaticWindowGravityAlignment` and `TrimMeanGravityAlignment`.
@@ -55,14 +41,8 @@
from eargait.preprocessing import align_gravity_and_convert_ear_to_ebf
from eargait.utils import StaticWindowGravityAlignment, TrimMeanGravityAlignment

gravity_method = "static"
static_method = StaticWindowGravityAlignment(sampling_rate_hz=50)
trim_method = TrimMeanGravityAlignment(sampling_rate_hz=50)

if gravity_method == "static":
    ear_data = align_gravity_and_convert_ear_to_ebf(align_calibrate_sess, static_method)
else:
    ear_data = align_gravity_and_convert_ear_to_ebf(align_calibrate_sess, trim_method)
trim_method = TrimMeanGravityAlignment(sampling_rate_hz=target_sample_rate)
ear_data = align_gravity_and_convert_ear_to_ebf(align_calibrate_sess, trim_method)
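
# %%
# Alternatively, the static-window gravity alignment from the removed branch can still
# be used via the import above (a sketch; the result is stored under a new name so the
# trimmed-mean result above stays intact):

static_method = StaticWindowGravityAlignment(sampling_rate_hz=target_sample_rate)
ear_data_static = align_gravity_and_convert_ear_to_ebf(align_calibrate_sess, static_method)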

# %%
# Initialize Gait Sequence Detection
@@ -72,7 +52,7 @@
# strictness=0 and minimum_seq_length=1 are the standard settings.
# strictness must be >= 0, while minimum_seq_length must be >= 1; see the definitions of these parameters in the GaitSequenceDetection class.

gsd = GaitSequenceDetection(sample_rate=50, strictness=0, minimum_seq_length=1)
gsd = GaitSequenceDetection(sample_rate=target_sample_rate, strictness=0, minimum_seq_length=1)

# %%
# Detect Gait Sequences
@@ -109,8 +89,9 @@
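# %%
# (The detection call itself is collapsed in this diff. A minimal sketch of the step,
# assuming the `detect` method used by eargait's GSD examples; the activity label is
# illustrative:)
#
# gsd.detect(ear_data, activity="walking")
# detected = gsd.sequence_list_["left_sensor"]
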
# The pipeline allows customization of parameters like `strictness` and `minimum_seq_length` to fine-tune the detection
# process based on the specific requirements of your dataset.

gsd = GaitSequenceDetection(sample_rate=50, strictness=0, minimum_seq_length=1)
gsd = GaitSequenceDetection(sample_rate=target_sample_rate, strictness=0, minimum_seq_length=1)

# %%
# Handling Multiple Activities
# ----------------------------
#
@@ -123,148 +104,19 @@
gsd.plot()
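
# %%
# (Also collapsed above: the multi-activity detection call. A sketch, assuming
# `activity` also accepts a list of labels; the labels are illustrative:)
#
# gsd.detect(ear_data, activity=["walking", "jogging"])
# gsd.plot()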


########################################################################################################################
# FROM HERE ON, EVERYTHING IS REMOVED FOR THE MINIMAL EXAMPLE
# %%
# Further useful analysis: overlay ground truth sequences and detected sequences
# -------------------------------------------------------------------------------
#
# To compare the detected sequences with any available ground truth data, we first need to make sure
# the ground truth is given at the same sampling rate.

from eargait.utils.overlapping_regions import categorize_intervals, plot_categorized_intervals

def downsample_ground_truth(csv_path, target_sample_rate):
    csv_table = pd.read_csv(csv_path)
    downsampling_factor = 200 / target_sample_rate
    csv_table["start"] = (csv_table["start"] / downsampling_factor).astype(int)
    csv_table["stop"] = (csv_table["stop"] / downsampling_factor).astype(int)
    return csv_table


tempo = get_mat_example_data_path().stem
csv_activity_table = downsample_ground_truth(csv_path, target_sample_rate=50)
csv_activity_table = csv_activity_table[csv_activity_table["speed"] == tempo]
csv_path_groundtruth = data_path.parent.joinpath("walking_bout_indices.csv")
csv_activity_table = load_groundtruth(csv_path_groundtruth, target_sample_rate=target_sample_rate)
csv_activity_table = csv_activity_table[csv_activity_table["speed"] == data_path.stem]
csv_activity_table = csv_activity_table.rename(columns={"stop": "end"})

# Plotting this overlays ground truth and detected activity sequences.
gsd.plot(csv_activity_table)


# %%
# Percentage representation of overlap
# -------------------------------------
#
# To get a single-number measure of how well the detection worked, we can report the percentage overlap of predicted
# and ground truth sequences, i.e. the true positive percentage of correctly identified walking sequences.

from eargait.utils.overlapping_regions import (
    categorize_intervals,
    categorize_intervals_per_sample,
    plot_categorized_intervals,
)


def calculate_tp_percentage(detected_sequences, ground_truth_sequences):
    categorized_intervals = categorize_intervals(detected_sequences, ground_truth_sequences)

    ground_truth_duration = sum(ground_truth_sequences["end"] - ground_truth_sequences["start"])
    true_positive_duration = sum(
        categorized_intervals.tp_intervals["end"] - categorized_intervals.tp_intervals["start"]
    )
    tp_percentage_gt = (true_positive_duration / ground_truth_duration) * 100 if ground_truth_duration > 0 else 0
    return tp_percentage_gt


detected_sequences = gsd.sequence_list_["left_sensor"][["start", "end"]]
ground_truth_sequences = csv_activity_table[["start", "end"]]
tp_percentage = calculate_tp_percentage(detected_sequences, ground_truth_sequences)

print(f"True Positive Percentage: {tp_percentage:.2f}%")


def calculate_sample_based_metrics(detected_sequences, ground_truth_sequences, total_samples=None):
    # Get categorized intervals (TP, FP, FN, TN) based on sample level matching
    categorized_intervals2 = categorize_intervals_per_sample(
        gsd_list_detected=detected_sequences,
        gsd_list_reference=ground_truth_sequences,
    )

    # Ground truth duration (total duration of ground truth sequences)
    ground_truth_duration = sum(ground_truth_sequences["end"] - ground_truth_sequences["start"])

    # True Positives (TP)
    tp_intervals = categorized_intervals2[categorized_intervals2["match_type"] == "tp"]
    true_positive_duration = sum(tp_intervals["end"] - tp_intervals["start"])

    # False Positives (FP)
    fp_intervals = categorized_intervals2[categorized_intervals2["match_type"] == "fp"]
    false_positive_duration = sum(fp_intervals["end"] - fp_intervals["start"])

    # False Negatives (FN)
    fn_intervals = categorized_intervals2[categorized_intervals2["match_type"] == "fn"]
    false_negative_duration = sum(fn_intervals["end"] - fn_intervals["start"])

    # True Negatives (TN), only calculated if total_samples is provided
    true_negative_duration = 0
    if total_samples is not None:
        tn_intervals = categorized_intervals2[categorized_intervals2["match_type"] == "tn"]
        true_negative_duration = sum(tn_intervals["end"] - tn_intervals["start"])

    # Calculate percentages for TP, FP, FN, and TN based on ground truth duration
    metrics = {
        "TP_percentage": (true_positive_duration / ground_truth_duration) * 100 if ground_truth_duration > 0 else 0,
        "FP_percentage": (false_positive_duration / ground_truth_duration) * 100 if ground_truth_duration > 0 else 0,
        "FN_percentage": (false_negative_duration / ground_truth_duration) * 100 if ground_truth_duration > 0 else 0,
        "TN_percentage": (true_negative_duration / ground_truth_duration) * 100 if ground_truth_duration > 0 else 0,
    }

    # Print the metrics
    print(f"True Positive Percentage (TP): {metrics['TP_percentage']:.2f}%")
    print(f"False Positive Percentage (FP): {metrics['FP_percentage']:.2f}%")
    print(f"False Negative Percentage (FN): {metrics['FN_percentage']:.2f}%")
    if total_samples is not None:
        print(f"True Negative Percentage (TN): {metrics['TN_percentage']:.2f}%")
    else:
        print("True Negative Percentage (TN): Not calculated (provide total_samples for TN calculation)")

    return categorized_intervals2, metrics

# Example usage


detected_sequences = gsd.sequence_list_["left_sensor"][["start", "end"]]
ground_truth_sequences = csv_activity_table[["start", "end"]]
total_samples = len(detected_sequences)  # You can replace this with the actual total sample count

categorized_intervals2, metrics = calculate_sample_based_metrics(detected_sequences, ground_truth_sequences, total_samples)
plot_categorized_intervals(
    gsd_list_detected=detected_sequences,
    gsd_list_reference=ground_truth_sequences,
    categorized_intervals=categorized_intervals2,
)
plt.show()

########################################################################################################################
# Strictness and min_length Parameters
# ------------------------------------
#
# strictness
#     Determines the size of the gap (in number of windows) at which two consecutive sequences are linked
#     together into a single sequence.
# minimum_seq_length
#     Determines the minimum length of a sequence (in windows). Needs to be >= 1.
sequence = pd.DataFrame(
    {
        "start": [2100, 2550, 3750, 4750, 6000, 7300, 8250, 9300, 10200],
        "end": [2550, 3600, 4450, 5850, 6900, 7950, 9000, 9450, 10350],
    }
)
print("Original sequence:", sequence)
sample_rate = 50
strictness = 2
minimum_seq_length = 2
gsd = GaitSequenceDetection(sample_rate=sample_rate, strictness=strictness, minimum_seq_length=minimum_seq_length)
sequence = gsd._ensure_strictness(sequence)
print("Sequence after strictness criterion:", sequence)
sequence = gsd._ensure_minimum_length(sequence)
print("Sequence after min_length criterion:", sequence)
2 changes: 1 addition & 1 deletion examples/load_data/README.rst
@@ -1,5 +1,5 @@
.. _examples-load_data:

Load Data by Signia Hearing Aids
===============================
================================
Demonstrations of different loading functionalities.
2 changes: 1 addition & 1 deletion examples/load_data/load_data.py
@@ -2,7 +2,7 @@
.. _example_load_data:
Load Data by Signia Hearing Aids
===============================
================================
This example shows you how to load data recorded with Signia hearing aids.
Please note that the private Python package signialib is required for running this example.
