Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Potentially Implement additional WBA rules #36

Open
AKuederle opened this issue Oct 11, 2023 · 2 comments
Open

Potentially Implement additional WBA rules #36

AKuederle opened this issue Oct 11, 2023 · 2 comments

Comments

@AKuederle
Copy link
Contributor

AKuederle commented Oct 11, 2023

During the process of reimplementing the WBA, I removed all rules that are no longer relevant, considering that we no longer filter WBs based on their elevation change or the existence of turns.

In this context, we also removed the concept of events from the WBA to make things simpler. It can be reintroduced once the rules below are reimplemented.

However, it might still be interesting to implement versions of these rules in the future.
Below is the source code of these rules.
Note that this code will not work with the version of the WBA currently implemented in gaitlink.
But they can serve as a starting point to finalize the implementations of these rules:

class EventCriteria:
    """Mixin with helpers for criteria that operate on a list of named events.

    The event list is expected to be a list of dicts, each with a ``"name"`` and an ``"events"`` entry.
    Each individual event is a dict with (at least) a ``"start"`` and an ``"end"`` entry.
    """

    # Name of the event entry (``event["name"]``) this criteria operates on.
    event_type: str

    def filter_events(self, event_list):
        """Return the events of type ``self.event_type`` after applying ``_filter_events``.

        Parameters
        ----------
        event_list
            List of dicts with the keys ``"name"`` and ``"events"``.

        Returns
        -------
        The (filtered) events of the first entry whose name matches ``self.event_type``.
        If no entry matches, the filter is applied to an empty list.

        Raises
        ------
        ValueError
            If ``event_list`` is None.

        """
        if event_list is None:
            raise ValueError("You are using an event based Criteria, without providing any events.")
        # Only the first entry with a matching name is considered.
        events = next(
            (event["events"] for event in event_list if event["name"] == self.event_type),
            [],
        )
        return self._filter_events(events)

    @staticmethod
    def _convert_to_start_stop(events):
        """Convert events into an ``(n_events, 2)`` array of ``[start, end]`` rows sorted by start."""
        event_start_end = np.array([[event["start"], event["end"]] for event in events])
        # An empty event list would otherwise produce a 1-D array and break the column indexing below.
        event_start_end = event_start_end.reshape(-1, 2)
        return event_start_end[event_start_end[:, 0].argsort()]

    def _filter_events(self, event_list: list[dict]) -> list[dict]:
        """Hook for subclasses to drop irrelevant events. The default keeps all events."""
        return event_list


# TODO: Precompile relevant event list for performance
class EventTerminationCriteria(BaseWBCriteria, EventCriteria):
    """Stop a WB (or postpone its start) when strides coincide with events.

    Parameters
    ----------
    event_type
        Name of the entry in the event list that this rule looks at.
    termination_mode
        Decides what makes a stride invalid for the WB.
        The rule inspects the window that reaches from the start of the previous stride up to (but excluding)
        the end of the current stride.

        "start": The stride is invalid, if an event starts within this window.
        "end": The stride is invalid, if an event ends within this window.
        "both": The stride is invalid, if any of the two conditions above applies.
        "ongoing": The stride is invalid, if an event starts within this window or covers the entire window.

        For the first stride in the stride list, the window starts at 0.

    """

    termination_mode: str

    _termination_modes = ("start", "end", "both", "ongoing")

    def __init__(
        self,
        event_type: str,
        termination_mode: Literal["start", "end", "both", "ongoing"] = "ongoing",
        comment: Optional[str] = None,
    ) -> None:
        self.event_type = event_type
        self.termination_mode = termination_mode
        super().__init__(comment=comment)

    def check_wb_start_end(
        self,
        stride_list: list[dict],
        original_start: int,
        current_start: int,
        current_end: int,
        event_list: Optional[list[dict]] = None,
    ) -> tuple[Optional[int], Optional[int]]:
        """Validate the stride at index ``current_end`` and propose a new WB start or end if it is invalid."""
        # The checked window opens at the start of the previous stride (or at 0 for the very first stride).
        window_start = stride_list[current_end - 1]["start"] if current_end >= 1 else 0
        window_end = stride_list[current_end]["end"]

        if self._check_stride(window_start, window_end, event_list):
            return None, None

        # Invalid stride: when the WB consists only of this stride, the start is pushed behind it ...
        if current_end == current_start:
            return current_end + 1, None
        # ... otherwise the WB is terminated right before it.
        return None, current_end - 1

    def _check_stride(self, last_stride_start: float, current_stride_end: float, event_list):
        """Return True, if the window ``[last_stride_start, current_stride_end)`` is valid for the configured mode."""
        mode = self.termination_mode
        if mode not in self._termination_modes:
            # Validate first, so that no events are processed when the mode is invalid anyway.
            raise ValueError(f'"termination_mode" must be one of {self._termination_modes}')

        events = self.filter_events(event_list)
        if not events:
            return True

        intervals = self._convert_to_start_stop(events)
        starts = intervals[:, 0]
        ends = intervals[:, 1]

        started_in_window = (starts >= last_stride_start) & (starts < current_stride_end)
        ended_in_window = (ends >= last_stride_start) & (ends < current_stride_end)
        nothing_started = not started_in_window.any()
        nothing_ended = not ended_in_window.any()

        if mode == "start":
            return nothing_started
        if mode == "end":
            return nothing_ended
        if mode == "both":
            return nothing_started and nothing_ended
        if mode == "ongoing":
            # Events that began before (or at) the window start and last until (at least) the window end.
            spans_window = (starts <= last_stride_start) & (ends >= current_stride_end)
            return nothing_started and not spans_window.any()
        # Unreachable, as the mode was validated above.
        raise ValueError()


class EventInclusionCriteria(BaseWBCriteria, EventCriteria):
    """Decide about a WB based on how it overlaps with events of one type.

    ``overlap`` selects the condition under which ``check_include`` returns True:

    "partial": at least one event has a positive overlap with the WB.
    "contains": at least one event lies completely within the WB.
    "is_contained": the WB lies completely within at least one event.
    "no_overlap": no event has a positive overlap with the WB.
    """

    event_type: str
    overlap: str

    _overlap_types = ("partial", "contains", "is_contained", "no_overlap")

    def __init__(
        self,
        event_type: str,
        overlap: str = "partial",
        comment: Optional[str] = None,
    ) -> None:
        self.event_type = event_type
        self.overlap = overlap

        super().__init__(comment=comment)

    def check_include(self, wb: dict, event_list: Optional[list[dict]] = None) -> bool:
        """Evaluate the configured overlap condition for ``wb``.

        If there are no (relevant) events, True is returned independent of the overlap type.
        """
        # TODO: TEST
        events = self.filter_events(event_list)
        if not events:
            return True

        intervals = self._convert_to_start_stop(events)
        event_starts = intervals[:, 0]
        event_ends = intervals[:, 1]

        # Length of the intersection of each event with the WB (not positive, if they do not intersect).
        clipped_ends = np.minimum(event_ends, wb["end"])
        clipped_starts = np.maximum(event_starts, wb["start"])
        amount_overlap = clipped_ends - clipped_starts

        mode = self.overlap
        if mode == "contains":
            return bool(np.any(amount_overlap >= event_ends - event_starts))
        if mode == "is_contained":
            return bool(np.any(amount_overlap >= wb["end"] - wb["start"]))
        if mode == "no_overlap":
            return not np.any(amount_overlap > 0)
        if mode == "partial":
            return bool(np.any(amount_overlap > 0))
        raise ValueError(f'"overlap" must be one of {self._overlap_types}')
        
class LevelWalkingCriteria(BaseWBCriteria):
    """Terminate (or delay the start of) a WB on long runs of consecutive non-level strides.

    A stride is non-level, if the absolute value of its ``field_name`` parameter is not NaN and is larger than or
    equal to ``level_walking_threshold``.
    NaN values are always treated as level walking.

    The rule triggers as soon as the run of non-level strides that ends at the current stride reaches one of the
    configured limits (the comparison is ``>=``).

    Parameters
    ----------
    level_walking_threshold
        Minimal absolute value of the stride parameter for a stride to be non-level. Must not be negative.
    max_non_level_strides
        Limit for the total number of strides in the run. If set, the per-foot limits are not evaluated.
    max_non_level_strides_left
        Limit for the number of strides with ``"foot" == "left"`` in the run.
    max_non_level_strides_right
        Limit for the number of strides with ``"foot" == "right"`` in the run.
    field_name
        Key in ``stride["parameter"]`` that holds the value compared against the threshold.
    comment
        Passed on to the base class.

    """

    max_non_level_strides: Optional[int]
    max_non_level_strides_left: Optional[int]
    max_non_level_strides_right: Optional[int]
    level_walking_threshold: float
    field_name: str

    @property
    def _max_lag(self) -> int:
        """Largest configured stride limit (0, if no limit is configured)."""
        configured = [
            val
            for val in (self.max_non_level_strides, self.max_non_level_strides_left, self.max_non_level_strides_right)
            if val is not None
        ]
        # NOTE(review): Without any limit the rule can never trigger, so 0 is returned instead of letting
        # ``max`` fail on an empty list - confirm this matches how the base class uses the value.
        return max(configured, default=0)

    def __init__(
        self,
        level_walking_threshold: float,
        max_non_level_strides: Optional[int] = None,
        max_non_level_strides_left: Optional[int] = None,
        max_non_level_strides_right: Optional[int] = None,
        field_name: str = "elevation",
        comment: Optional[str] = None,
    ) -> None:
        self.max_non_level_strides = max_non_level_strides
        self.max_non_level_strides_left = max_non_level_strides_left
        self.max_non_level_strides_right = max_non_level_strides_right
        if level_walking_threshold < 0:
            raise ValueError("`level_walking_threshold` must be >=0.")
        self.level_walking_threshold = level_walking_threshold
        self.field_name = field_name
        super().__init__(comment=comment)

    def check_wb_start_end(
        self,
        stride_list: list[dict],
        original_start: int,
        current_start: int,
        current_end: int,
        event_list: Optional[list[dict]] = None,
    ) -> tuple[Optional[int], Optional[int]]:
        """Check the run of non-level strides that ends at ``current_end``.

        Returns
        -------
        ``(None, None)``, if the run is empty or does not reach a limit.
        ``(current_end + 1, None)``, if all strides since ``original_start`` belong to the run.
        ``(None, current_end - len(run))``, if the run started in the middle of the WB.

        """
        past_strides = stride_list[original_start : current_end + 1]
        # Walk backwards from the current stride and collect strides until the first level stride.
        consecutive_section = []
        for stride in reversed(past_strides):
            # We consider nan values always as level walking!
            stride_height_change = abs(stride["parameter"][self.field_name])
            if not np.isnan(stride_height_change) and stride_height_change >= self.level_walking_threshold:
                consecutive_section.append(stride)
            else:
                break
        if not consecutive_section:
            return None, None
        # Restore chronological order.
        consecutive_section.reverse()

        if not self._check_subsequence(consecutive_section):
            return None, None

        # If we are at the beginning of the WB, we will change the start.
        if len(consecutive_section) == len(past_strides):
            return current_end + 1, None

        # If we are in the middle of a WB, we want to terminate it
        return None, current_end - len(consecutive_section)

    def _check_subsequence(self, stride_list) -> bool:
        """Check if the detected part reaches one of our limits."""
        if self.max_non_level_strides is not None:
            return len(stride_list) >= self.max_non_level_strides
        left_limit = self.max_non_level_strides_left
        right_limit = self.max_non_level_strides_right
        if left_limit is None and right_limit is None:
            return False
        foot_count = Counter(s["foot"] for s in stride_list)
        # Both per-foot limits are evaluated; reaching either of them is sufficient.
        left_reached = left_limit is not None and foot_count["left"] >= left_limit
        right_reached = right_limit is not None and foot_count["right"] >= right_limit
        return left_reached or right_reached


class TurnAngleCriteria(EventTerminationCriteria):
    """Event termination rule (mode "ongoing") for "turn" events within given angle and turn-rate limits.

    Only turns whose absolute angle and absolute angle per duration are within the configured limits are
    considered by the rule. Turns without a positive duration are never considered.
    """

    event_type: str = "turn"
    min_turn_angle: float
    max_turn_angle: float
    min_turn_rate: float
    max_turn_rate: float

    _serializable_paras = (
        "min_turn_angle",
        "max_turn_angle",
        "min_turn_rate",
        "max_turn_rate",
    )

    _rule_type: str = "turn_event"

    def __init__(
        self,
        min_turn_angle: Optional[float] = None,
        max_turn_angle: Optional[float] = None,
        min_turn_rate: Optional[float] = None,
        max_turn_rate: Optional[float] = None,
        comment: Optional[str] = None,
    ) -> None:
        self.min_turn_angle = min_turn_angle
        self.max_turn_angle = max_turn_angle
        self.min_turn_rate = min_turn_rate
        self.max_turn_rate = max_turn_rate
        super().__init__(event_type=self.event_type, termination_mode="ongoing", comment=comment)

    def _filter_events(self, event_list: list[dict]) -> list[dict]:
        """Keep only the turns that are within the angle and the turn-rate limits."""
        # NOTE(review): check_thresholds presumably validates the pairs and replaces None by +/-inf - confirm.
        angle_lower, angle_upper = check_thresholds(self.min_turn_angle, self.max_turn_angle, allow_both_none=True)
        rate_lower, rate_upper = check_thresholds(self.min_turn_rate, self.max_turn_rate, allow_both_none=True)

        selected = []
        for turn in event_list:
            abs_angle = np.abs(turn["parameter"]["angle"])
            if not angle_lower <= abs_angle <= angle_upper:
                continue
            duration = turn["end"] - turn["start"]
            # Skipping non-positive durations also protects the division below.
            if not duration > 0:
                continue
            if rate_lower <= abs_angle / duration <= rate_upper:
                selected.append(turn)
        return selected
@AKuederle
Copy link
Contributor Author

A list of tests that were used to test the functionality of these rules:

@pytest.mark.parametrize("incline_strides", ([0, 1], [4, 5], [-1, -2]))
def test_incline_to_short(incline_strides):
    """Test gait sequences with an incline section that is too short to trigger the rule.

    Scenario:
        - 2 incline strides at the beginning/center/end
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - Single WB that includes the 2 incline strides

    """
    t0 = 0
    stride_count = 10
    strides = [window(t0 + k, t0 + k + 1, parameter={"elevation": 0}) for k in range(stride_count)]
    for idx in incline_strides:
        strides[idx]["parameter"]["elevation"] = 1.0

    wba = WBAssembly(
        [
            ("break", MaxBreakCriteria(3)),
            ("n_strides", NStridesCriteria(4)),
            ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
        ]
    )
    wba.assemble(strides)

    # Neither WBs nor strides may be excluded.
    for leftovers in (
        wba.excluded_wb_list_,
        wba.excluded_stride_list_,
        wba.exclusion_reasons_,
        wba.stride_exclusion_reasons_,
    ):
        assert len(leftovers) == 0

    first_wb = wba.wb_list_[0]
    assert first_wb["end"] == t0 + stride_count
    assert first_wb["start"] == t0
    assert first_wb["strideList"] == strides
    assert len(wba.wb_list_) == 1


@pytest.mark.parametrize("incline_strides", ([0, 1, 2, 3], [-1, -2, -3, -4]))
def test_incline_start_end_long(incline_strides):
    """Test gait sequences that start (or end) with a long incline section.

    Scenario:
        - 4 incline strides at the beginning (or end)
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - Single WB that starts after/ends before the incline strides
        - All incline strides are discarded
    """
    t0 = 0
    stride_count = 10
    strides = [window(t0 + k, t0 + k + 1, parameter={"elevation": 0}) for k in range(stride_count)]
    for idx in incline_strides:
        strides[idx]["parameter"]["elevation"] = 1.0

    wba = WBAssembly(
        [
            ("break", MaxBreakCriteria(3)),
            ("n_strides", NStridesCriteria(4)),
            ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
        ]
    )
    wba.assemble(strides)

    assert len(wba.excluded_stride_list_) == 4
    assert len(wba.stride_exclusion_reasons_) == 4

    assert len(wba.wb_list_) == 1

    n_incline = len(incline_strides)
    # Positive indices mean that the incline section is at the start of the sequence.
    incline_at_start = incline_strides[-1] > 0
    if incline_at_start:
        assert len(wba.excluded_wb_list_) == 0
        assert len(wba.exclusion_reasons_) == 0
        remaining_wb = wba.wb_list_[0]
        assert remaining_wb["end"] == t0 + stride_count
        assert remaining_wb["start"] == t0 + n_incline
        assert remaining_wb["strideList"] == strides[4:]
    else:
        assert len(wba.excluded_wb_list_) == 1
        assert len(wba.exclusion_reasons_) == 1
        remaining_wb = wba.wb_list_[0]
        assert remaining_wb["end"] == t0 + stride_count - n_incline
        assert remaining_wb["start"] == t0
        assert remaining_wb["strideList"] == strides[:-4]


def test_incline_center_long():
    """Test gait sequences that have a long incline period in the center.

    Scenario:
        - 5 incline strides in the center
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - Two WBs (one before, one after the incline period)
        - All incline strides are discarded
    """
    wb_start_time = 0
    n_strides = 15
    strides = [window(wb_start_time + i, wb_start_time + i + 1, parameter={"elevation": 0}) for i in range(n_strides)]
    incline_strides = [6, 7, 8, 9, 10]
    for s in incline_strides:
        strides[s]["parameter"]["elevation"] = 1.0
    rules = [
        ("break", MaxBreakCriteria(3)),
        ("n_strides", NStridesCriteria(4)),
        ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
    ]

    wba = WBAssembly(rules)
    wba.assemble(strides)

    assert len(wba.excluded_wb_list_) == 0
    assert len(wba.excluded_stride_list_) == len(incline_strides)
    assert len(wba.exclusion_reasons_) == 0
    assert len(wba.stride_exclusion_reasons_) == len(incline_strides)
    # Every discarded stride must have been removed by the level walking rule.
    for r in wba.stride_exclusion_reasons_.values():
        assert r == rules[2]

    assert len(wba.wb_list_) == 2
    assert wba.wb_list_[0]["start"] == wb_start_time
    assert wba.wb_list_[0]["end"] == wb_start_time + incline_strides[0]
    assert wba.wb_list_[0]["strideList"] == strides[: incline_strides[0]]

    first_stride_after_incline = incline_strides[-1] + 1
    # The stride start already includes ``wb_start_time``, so the offset must not be added a second time.
    assert wba.wb_list_[1]["start"] == strides[first_stride_after_incline]["start"]
    assert wba.wb_list_[1]["end"] == wb_start_time + n_strides
    assert wba.wb_list_[1]["strideList"] == strides[first_stride_after_incline:]


def test_short_incline_after_break():
    """Test what happens if a short incline section directly follows a break.

    Scenario:
        - valid WB, then a break, then 2 incline strides, then a valid WB
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - 2 WBs: One before the break, the second after the break
        - The second one includes the 2 incline strides
    """

    def level_strides(first_start, count):
        return [window(first_start + k, first_start + k + 1, parameter={"elevation": 0}) for k in range(count)]

    count_1, start_1 = 5, 0
    first_stride_list = level_strides(start_1, count_1)

    count_2 = 10
    start_2 = start_1 + count_1 + 5
    second_stride_list = level_strides(start_2, count_2)
    for idx in (0, 1):
        second_stride_list[idx]["parameter"]["elevation"] = 1.0

    strides = first_stride_list + second_stride_list
    wba = WBAssembly(
        [
            ("break", MaxBreakCriteria(3)),
            ("n_strides", NStridesCriteria(4)),
            ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
        ]
    )
    wba.assemble(strides)

    assert len(wba.excluded_wb_list_) == 0
    assert len(wba.excluded_stride_list_) == 0
    assert len(wba.exclusion_reasons_) == 0
    assert len(wba.stride_exclusion_reasons_) == 0

    assert len(wba.wb_list_) == 2
    wb_before, wb_after = wba.wb_list_[0], wba.wb_list_[1]

    assert wb_before["start"] == start_1
    assert wb_before["end"] == start_1 + count_1
    assert wb_before["strideList"] == first_stride_list

    assert wb_after["start"] == start_2
    assert wb_after["end"] == start_2 + count_2
    assert wb_after["strideList"] == second_stride_list


@pytest.mark.parametrize("n_incline_start", (2, 4, 1))
def test_short_incline_after_break_incline_before_break(n_incline_start):
    """Test what happens if there are incline strides before and (a few) after a break.

    Scenario:
        - some incline strides (invalid WB), then a break, then 2 incline strides, then a valid WB
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - 1 WB after the break
        - The WB includes the 2 incline strides
    """

    def make_strides(first_start, count, elevation):
        return [window(first_start + k, first_start + k + 1, parameter={"elevation": elevation}) for k in range(count)]

    count_1, start_1 = n_incline_start, 0
    first_stride_list = make_strides(start_1, count_1, 1.0)

    count_2 = 10
    start_2 = start_1 + count_1 + 5
    second_stride_list = make_strides(start_2, count_2, 0)
    for idx in (0, 1):
        second_stride_list[idx]["parameter"]["elevation"] = 1.0

    strides = first_stride_list + second_stride_list
    wba = WBAssembly(
        [
            ("break", MaxBreakCriteria(3)),
            ("n_strides", NStridesCriteria(4)),
            ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
        ]
    )
    wba.assemble(strides)

    assert len(wba.excluded_wb_list_) == 1
    assert len(wba.excluded_stride_list_) == n_incline_start
    assert len(wba.exclusion_reasons_) == 1
    assert len(wba.stride_exclusion_reasons_) == n_incline_start

    assert len(wba.wb_list_) == 1
    remaining_wb = wba.wb_list_[0]

    assert remaining_wb["start"] == start_2
    assert remaining_wb["end"] == start_2 + count_2
    assert remaining_wb["strideList"] == second_stride_list


@pytest.mark.parametrize("n_incline_start", (2, 4, 1))
def test_long_incline_after_break_incline_before_break(n_incline_start):
    """Test what happens if there are incline strides before and (many) after a break.

    Scenario:
        - some incline strides (invalid WB), then a break, then 4 incline strides, then a valid WB
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - 1 WB after the break
        - The WB starts after the 4 incline strides
    """

    def make_strides(first_start, count, elevation):
        return [window(first_start + k, first_start + k + 1, parameter={"elevation": elevation}) for k in range(count)]

    count_1, start_1 = n_incline_start, 0
    first_stride_list = make_strides(start_1, count_1, 1.0)

    count_2 = 10
    start_2 = start_1 + count_1 + 5
    second_stride_list = make_strides(start_2, count_2, 0)
    n_incline_after_break = 4
    for idx in range(n_incline_after_break):
        second_stride_list[idx]["parameter"]["elevation"] = 1.0

    strides = first_stride_list + second_stride_list
    wba = WBAssembly(
        [
            ("break", MaxBreakCriteria(3)),
            ("n_strides", NStridesCriteria(4)),
            ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
        ]
    )
    wba.assemble(strides)

    expected_excluded_strides = n_incline_start + n_incline_after_break
    assert len(wba.excluded_wb_list_) == 1
    assert len(wba.excluded_stride_list_) == expected_excluded_strides
    assert len(wba.exclusion_reasons_) == 1
    assert len(wba.stride_exclusion_reasons_) == expected_excluded_strides

    assert len(wba.wb_list_) == 1
    remaining_wb = wba.wb_list_[0]

    assert remaining_wb["start"] == start_2 + n_incline_after_break
    assert remaining_wb["end"] == start_2 + count_2
    assert remaining_wb["strideList"] == second_stride_list[n_incline_after_break:]


def test_long_incline_after_break():
    """Test what happens if a long incline section directly follows a break.

    Scenario:
        - valid WB, then a break, then 4 incline strides, then a valid WB
        - No left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - LevelWalkingCriteria

    Outcome:
        - 2 WBs: One before the break, the second after the break
        - The second one should start AFTER the incline strides
    """

    def level_strides(first_start, count):
        return [window(first_start + k, first_start + k + 1, parameter={"elevation": 0}) for k in range(count)]

    count_1, start_1 = 5, 0
    first_stride_list = level_strides(start_1, count_1)

    count_2 = 10
    start_2 = start_1 + count_1 + 5
    second_stride_list = level_strides(start_2, count_2)
    n_incline = 4
    for idx in range(n_incline):
        second_stride_list[idx]["parameter"]["elevation"] = 1.0

    strides = first_stride_list + second_stride_list
    wba = WBAssembly(
        [
            ("break", MaxBreakCriteria(3)),
            ("n_strides", NStridesCriteria(4)),
            ("level_walk", LevelWalkingCriteria(0.5, max_non_level_strides=3)),
        ]
    )
    wba.assemble(strides)

    assert len(wba.excluded_wb_list_) == 0
    assert len(wba.excluded_stride_list_) == n_incline
    assert len(wba.exclusion_reasons_) == 0
    assert len(wba.stride_exclusion_reasons_) == n_incline

    assert len(wba.wb_list_) == 2
    wb_before, wb_after = wba.wb_list_[0], wba.wb_list_[1]

    assert wb_before["start"] == start_1
    assert wb_before["end"] == start_1 + count_1
    assert wb_before["strideList"] == first_stride_list

    assert wb_after["start"] == start_2 + n_incline
    assert wb_after["end"] == start_2 + count_2
    assert wb_after["strideList"] == second_stride_list[4:]


def test_turns_between_wbs():
    """Test that turns correctly interrupt and restart WBs.

    Scenario:
        - long gait sequence with 2 turns in the center
        - no left/right foot

    Rules:
        - BreakCriterium
        - MinStride Inclusion Rule
        - TurnCriteria

    Outcome:
        - 3 WBs
        - All strides "touched" by a turn are discarded
    """
    stride_count = 20
    t0 = 0
    strides = [window(t0 + k, t0 + k + 1) for k in range(stride_count)]
    turn_events = [
        window(5, 8, parameter={"angle": 180}),
        window(12, 15, parameter={"angle": 180}),
    ]
    event_list = [{"name": "turn", "events": turn_events}]

    rules = [
        ("break", MaxBreakCriteria(3)),
        ("n_strides", NStridesCriteria(4)),
        ("turns", TurnAngleCriteria(min_turn_angle=150, min_turn_rate=50)),
    ]
    turn_rule = rules[2]

    wba = WBAssembly(rules)
    wba.assemble(strides, event_list=event_list)

    assert len(wba.excluded_wb_list_) == 0
    assert len(wba.excluded_stride_list_) == 6
    assert len(wba.exclusion_reasons_) == 0
    assert len(wba.stride_exclusion_reasons_) == 6

    # All discarded strides must have been removed by the turn rule.
    for reason in wba.stride_exclusion_reasons_.values():
        assert reason == turn_rule

    assert len(wba.wb_list_) == 3
    wb_1, wb_2, wb_3 = wba.wb_list_[0], wba.wb_list_[1], wba.wb_list_[2]

    assert wb_1["start"] == t0
    assert wb_1["end"] == 5
    assert wb_1["strideList"] == strides[:5]

    assert wb_2["start"] == 8
    assert wb_2["end"] == 12
    assert wb_2["strideList"] == strides[8:12]

    assert wb_3["start"] == 15
    assert wb_3["end"] == stride_count
    assert wb_3["strideList"] == strides[15:]


def test_combined_turn_and_break():
    """Test what happens if a turn and a break overlap.

    Scenario:
        - A long WB with two turns and two breaks
        - The two turns and the two breaks overlap (e.g. the system was not able to detect turning strides)
        - The first stride after each break has partial overlap with a turn.
    """
    stride_count = 25
    t0 = 0
    strides = [window(t0 + k, t0 + k + 1) for k in range(stride_count)]
    # Cut the two breaks into the sequence (later one first, so the indices of the earlier one stay valid).
    del strides[13:17]
    del strides[4:7]
    turn_events = [
        window(5, 7.5, parameter={"angle": 180}),
        window(15, 17.5, parameter={"angle": 180}),
    ]
    event_list = [{"name": "turn", "events": turn_events}]

    rules = [
        ("break", MaxBreakCriteria(3)),
        ("n_strides", NStridesCriteria(4)),
        ("turns", TurnAngleCriteria(min_turn_angle=150, min_turn_rate=50)),
    ]
    turn_rule = rules[2]

    wba = WBAssembly(rules)
    wba.assemble(strides, event_list=event_list)

    assert len(wba.excluded_wb_list_) == 0
    assert len(wba.excluded_stride_list_) == 2
    assert len(wba.exclusion_reasons_) == 0
    assert len(wba.stride_exclusion_reasons_) == 2

    for reason in wba.stride_exclusion_reasons_.values():
        assert reason == turn_rule

    assert len(wba.wb_list_) == 3
    wb_1, wb_2, wb_3 = wba.wb_list_[0], wba.wb_list_[1], wba.wb_list_[2]

    assert wb_1["start"] == t0
    assert wb_1["end"] == 4
    assert wb_1["strideList"] == strides[:4]

    assert wb_2["start"] == 8
    assert wb_2["end"] == 13
    assert wb_2["strideList"] == strides[5:10]

    assert wb_3["start"] == 18
    assert wb_3["end"] == stride_count
    assert wb_3["strideList"] == strides[11:]

@AKuederle
Copy link
Contributor Author

And more tests:

class TestEventTerminationCriteria(BaseTestCriteriaTermination):
    """Tests for the four termination modes of ``EventTerminationCriteria``."""

    criteria_class = EventTerminationCriteria
    defaults = {"event_type": "event1", "termination_mode": "ongoing", "comment": "comment"}

    # Each case: ((last stride, current stride), event), all given as (start, end).
    test_cases = (
        (((0, 1), (1.5, 2.5)), (1.2, 1.3)),  # 1 Started and stopped between strides
        (((0, 1), (1.5, 2.5)), (0.5, 2)),  # 2 Started in last stride and stopped in current stride
        (((0, 1), (1.5, 2.5)), (1, 1.2)),  # 3 Started at end of last stride and stopped between strides
        (((0, 1), (1.5, 2.5)), (1.7, 3)),  # 4 Started in current stride and ended after
        (((0, 1), (1.5, 2.5)), (1.7, 2)),  # 5 Started and ended in current stride
        (((0, 1), (1.5, 2.5)), (0.2, 0.8)),  # 6 Started and ended in last stride
        (((1, 2), (2.5, 3.5)), (0.5, 4.5)),  # 7 Started before last stride and stopped after current stride
        # Cases that are not checked because they occur before the last stride
        (((1, 2), (2.5, 3.5)), (0.5, 0.8)),  # 8 Started and stopped before last stride
        # Cases that are not checked because they occur after the current stride
        (((0, 1), (1.5, 2.5)), (3, 3.5)),  # 9 Started and stopped after current stride
        # Edge cases
        (((0, 1), (1.5, 2.5)), (1, 2)),  # 10 Started at end of last stride and stopped in current stride
        (((0, 1), (1.5, 2.5)), (1, 1.2)),  # 11 Started at end of last stride and stopped between strides
        (((0, 1), (1.5, 2.5)), (0, 1)),  # 12 Exactly the last stride
        (((0, 1), (1.5, 2.5)), (1.5, 2.5)),  # 13 Exactly the current stride
        (((0, 1), (1.5, 2.5)), (0, 2.5)),  # 14 Exactly both strides
        (((0, 1), (1.5, 2.5)), (2.5, 3)),  # 15 Started at end of current stride
    )

    # Expected validity of the current stride per case (True: the WB is not terminated).
    start_results = (
        False, False, False, False, False, False,  # 1 - 6
        True, True, True,  # 7 - 9
        False, False, False, False, False,  # 10 - 14
        True,  # 15
    )

    end_results = (
        False, False, False,  # 1 - 3
        True,  # 4
        False, False,  # 5 - 6
        True, True, True,  # 7 - 9
        False, False, False,  # 10 - 12
        True, True, True,  # 13 - 15
    )

    # A stride is valid in "both" mode only if it is valid in "start" and in "end" mode.
    both_results = np.array(start_results) & np.array(end_results)

    ongoing_results = (
        False, False, False, False, False, False, False,  # 1 - 7
        True, True,  # 8 - 9
        False, False, False, False, False,  # 10 - 14
        True,  # 15
    )

    @pytest.fixture(
        params=(
            (
                {"event_type": "event1", "termination_mode": "ongoing", "comment": "comment"},
                {"event_type": "event1", "termination_mode": "ongoing", "comment": "comment"},
            ),
            (
                {"event_type": "event1"},
                {"event_type": "event1", "termination_mode": "ongoing"},
            ),
            (
                {"event_type": "event1", "termination_mode": "start"},
                {"event_type": "event1", "termination_mode": "start"},
            ),
            (
                {"event_type": "event1", "termination_mode": "end"},
                {"event_type": "event1", "termination_mode": "end"},
            ),
            (
                {"event_type": "event1", "termination_mode": "both"},
                {"event_type": "event1", "termination_mode": "both"},
            ),
            ({"event_type": "event1", "termination_mode": "something_wrong"}, ValueError),
        )
    )
    def init_data(self, request):
        """Provide pairs of init parameters and the expected outcome."""
        return request.param

    @staticmethod
    def _test_check(strides_events, result, termination_mode):
        """Run the rule on two strides and one event and compare with the expected validity."""
        stride_bounds, event_bounds = strides_events
        criteria = EventTerminationCriteria(event_type="event1", termination_mode=termination_mode)
        stride_list = [window(*stride_bounds[0]), window(*stride_bounds[1])]
        event_list = [{"name": "event1", "events": [window(*event_bounds)]}]

        new_start, new_end = criteria.check_wb_start_end(
            stride_list,
            event_list=event_list,
            original_start=0,
            current_start=0,
            current_end=len(stride_list) - 1,
        )

        # An invalid current stride must terminate the WB at the stride before it.
        expected_end = None if bool(result) else len(stride_list) - 2
        assert new_end == expected_end
        assert new_start is None

    @pytest.mark.parametrize(("strides_events", "result"), (zip(test_cases, start_results)))
    def test_check_start_termination(self, strides_events, result):
        self._test_check(strides_events, result, "start")

    # TODO: Check moving start
    # TODO: Check event before the first stride

    @pytest.mark.parametrize(("strides_events", "result"), (zip(test_cases, end_results)))
    def test_check_end_termination(self, strides_events, result):
        self._test_check(strides_events, result, "end")

    @pytest.mark.parametrize(("strides_events", "result"), (zip(test_cases, both_results)))
    def test_check_both_termination(self, strides_events, result):
        self._test_check(strides_events, result, "both")

    @pytest.mark.parametrize(("strides_events", "result"), (zip(test_cases, ongoing_results)))
    def test_check_ongoing_termination(self, strides_events, result):
        self._test_check(strides_events, result, "ongoing")


class TestMaxTurnAngleCriteria(BaseTestCriteriaTermination):
    """Tests for ``TurnAngleCriteria``: init parameters and filtering of ``"turn"`` events."""

    criteria_class = TurnAngleCriteria
    defaults = {"min_turn_angle": 0, "max_turn_angle": 1, "comment": "comment"}

    # Every entry is a pair ``(init kwargs, expectation)``. The expectation is either a dict of
    # parameter values or an exception class (presumably evaluated by the base class tests --
    # verify against ``BaseTestCriteriaTermination``).
    @pytest.fixture(
        params=(
            # All parameters given explicitly are expected back unchanged.
            (
                {
                    "min_turn_angle": -1,
                    "max_turn_angle": 0,
                    "min_turn_rate": 0,
                    "max_turn_rate": 1,
                    "name": "name",
                    "comment": "comment",
                },
                {
                    "min_turn_angle": -1,
                    "max_turn_angle": 0,
                    "min_turn_rate": 0,
                    "max_turn_rate": 1,
                    "name": "name",
                    "comment": "comment",
                },
            ),
            # A ``None`` angle bound is paired with the matching infinite bound.
            (
                {"min_turn_angle": None, "max_turn_angle": 0},
                {"min_turn_angle": -np.inf, "max_turn_angle": 0},
            ),
            (
                {"min_turn_angle": 0, "max_turn_angle": None},
                {"min_turn_angle": 0, "max_turn_angle": np.inf},
            ),
            # Equal or inverted angle bounds are paired with ``ValueError``.
            ({"min_turn_angle": 0, "max_turn_angle": 0}, ValueError),
            ({"min_turn_angle": 2, "max_turn_angle": 1}, ValueError),
            # Same set of cases for the turn-rate bounds.
            (
                {"min_turn_rate": None, "max_turn_rate": 0},
                {"min_turn_rate": -np.inf, "max_turn_rate": 0},
            ),
            (
                {"min_turn_rate": 0, "max_turn_rate": None},
                {"min_turn_rate": 0, "max_turn_rate": np.inf},
            ),
            ({"min_turn_rate": 0, "max_turn_rate": 0}, ValueError),
            ({"min_turn_rate": 2, "max_turn_rate": 1}, ValueError),
            # No arguments at all: fully open ranges and no name/comment.
            (
                {},
                {
                    "min_turn_angle": -np.inf,
                    "max_turn_angle": np.inf,
                    "min_turn_rate": -np.inf,
                    "max_turn_rate": np.inf,
                    "name": None,
                    "comment": None,
                },
            ),
        )
    )
    def init_data(self, request):
        """Return the currently active ``(init kwargs, expectation)`` pair."""
        return request.param

    @pytest.mark.parametrize(
        ("strides_events", "result"),
        zip(
            TestEventTerminationCriteria.test_cases,
            TestEventTerminationCriteria.ongoing_results,
        ),
    )
    def test_check_ongoing(self, strides_events, result):
        """Re-run the "ongoing" termination cases of ``TestEventTerminationCriteria``.

        NOTE(review): the shared helper instantiates ``EventTerminationCriteria`` itself, so this
        test does not construct a ``TurnAngleCriteria`` -- confirm that this is intended.
        """
        TestEventTerminationCriteria._test_check(strides_events, result, "ongoing")

    def test_event_filter_angle(self):
        """Only turn events whose angle lies within the configured range are kept."""
        base_angle = 10
        kept = [window(0, 1, parameter={"angle": base_angle}) for _ in range(10)]
        dropped = [window(0, 1, parameter={"angle": base_angle + 10}) for _ in range(10)]
        events = [{"name": "turn", "events": kept + dropped}]

        # Range covers only the smaller angle.
        criteria = TurnAngleCriteria(base_angle - 5, base_angle + 5)
        assert criteria.filter_events(events) == kept

        # Range covers both angles.
        criteria = TurnAngleCriteria(base_angle - 5, base_angle + 20)
        assert criteria.filter_events(events) == events[0]["events"]

        # Range lies completely below both angles.
        criteria = TurnAngleCriteria(base_angle - 10, base_angle - 5)
        assert criteria.filter_events(events) == []

        # Range lies completely above both angles.
        criteria = TurnAngleCriteria(base_angle + 11, base_angle + 20)
        assert criteria.filter_events(events) == []

    def test_event_filter_rate(self):
        """Only turn events whose turn rate lies within the configured range are kept."""
        base_angle = 10
        # Same angle, different window lengths -> different turn rates.
        kept = [window(0, 10, parameter={"angle": base_angle}) for _ in range(10)]  # -> Turnrate 1
        dropped = [window(0, 5, parameter={"angle": base_angle}) for _ in range(10)]  # -> Turnrate 2
        events = [{"name": "turn", "events": kept + dropped}]

        criteria = TurnAngleCriteria(min_turn_rate=0, max_turn_rate=1.5)
        assert criteria.filter_events(events) == kept

        # Without any limits nothing is filtered out.
        criteria = TurnAngleCriteria()
        assert criteria.filter_events(events) == events[0]["events"]

        # A rate exactly on the upper bound is still kept.
        criteria = TurnAngleCriteria(min_turn_rate=0, max_turn_rate=1)
        assert criteria.filter_events(events) == kept

        # A rate exactly on the lower bound is still kept.
        criteria = TurnAngleCriteria(min_turn_rate=1, max_turn_rate=1.5)
        assert criteria.filter_events(events) == kept


class TestLevelWalkingCriteria(BaseTestCriteriaTermination):
    """Tests for ``LevelWalkingCriteria``: init parameters and termination on non-level strides."""

    criteria_class = LevelWalkingCriteria
    defaults = {
        "max_non_level_strides": 1,
        "max_non_level_strides_left": 2,
        "max_non_level_strides_right": 3,
        "level_walking_threshold": 4,
        "name": "name",
        "comment": "comment",
    }

    # Every entry is a pair ``(init kwargs, expectation)``; the expectation is either a dict of
    # parameter values or an exception class.
    @pytest.fixture(
        params=(
            # All parameters given explicitly are expected back unchanged.
            (
                {
                    "max_non_level_strides": 1,
                    "max_non_level_strides_left": 2,
                    "max_non_level_strides_right": 3,
                    "level_walking_threshold": 4,
                    "name": "name",
                    "comment": "comment",
                },
                {
                    "max_non_level_strides": 1,
                    "max_non_level_strides_left": 2,
                    "max_non_level_strides_right": 3,
                    "level_walking_threshold": 4,
                    "name": "name",
                    "comment": "comment",
                },
            ),
            # A negative threshold is paired with ``ValueError``.
            (
                {
                    "max_non_level_strides": 1,
                    "max_non_level_strides_left": 2,
                    "max_non_level_strides_right": 3,
                    "level_walking_threshold": -3,
                    "name": "name",
                    "comment": "comment",
                },
                ValueError,
            ),
        )
    )
    def init_data(self, request):
        """Return the currently active ``(init kwargs, expectation)`` pair."""
        return request.param

    @pytest.mark.skip(reason="Not applicable for this criteria")
    def test_single_stride(self, naive_stride_list, naive_event_list):
        """Skipped: not applicable for this criteria."""

    @pytest.mark.parametrize(
        ("invalid_strides", "expected_end"),
        (
            ([4, 5, 6], 4),
            ([4, 5], 20),
            ([4, 5, 6, 7], 4),
            ([4, 5, 7, 8], 20),
            ([4, 5, 7, 8, 9], 7),
        ),
    )
    def test_simple_break(self, invalid_strides, expected_end):
        """Check where the first WB ends for different sets of non-level strides.

        20 strides of length 1 are created; the strides listed in ``invalid_strides`` get an
        elevation above the threshold. With ``max_non_level_strides=3`` the cases expect the WB
        to end at the start of the first run of three consecutive non-level strides, or at 20
        (the end of the last stride) if there is no such run.
        """
        # Function-scope import so that the file-level import block does not need to change.
        from itertools import cycle

        # ``repeat(("left", "right"))`` would hand the whole tuple to every stride;
        # ``cycle`` alternates the two foot labels as intended.
        lr = cycle(("left", "right"))
        thres = 2
        non_level = set(invalid_strides)
        strides = [
            window(
                start=s,
                end=s + 1,
                foot=next(lr),
                parameter={"elevation": thres + 1 if s in non_level else 0},
            )
            for s in range(20)
        ]

        # Test directly in combination with the WBA
        wba = WBAssembly([("level_walk", LevelWalkingCriteria(max_non_level_strides=3, level_walking_threshold=thres))])
        wba.assemble(strides)

        assert expected_end == wba.wb_list_[0]["end"]

    def test_combined_break_rules(self):
        """Test multiple delayed rules.

        This test checks what happens if multiple delay rules are active at the same time.
        """
        # Function-scope import so that the file-level import block does not need to change.
        from itertools import cycle

        # First rule fires earlier with higher threshold
        r1 = LevelWalkingCriteria(max_non_level_strides=2, level_walking_threshold=2)
        # Second rule has higher delay but lower threshold
        r2 = LevelWalkingCriteria(max_non_level_strides=4, level_walking_threshold=1)

        test_elevation = [0, 0, 0, 0, 1.5, 1.5, 1.5, 2.5, 2.5, 2.5]
        # ``cycle`` (not ``repeat``) so that the foot label alternates between strides.
        lr = cycle(("left", "right"))
        strides = [
            window(start=s, end=s + 1, foot=next(lr), parameter={"elevation": elevation})
            for s, elevation in enumerate(test_elevation)
        ]

        # only rule one
        wba = WBAssembly([("r1", r1)])
        wba.assemble(strides)

        assert wba.wb_list_[0]["end"] == 7

        # only rule two
        wba = WBAssembly([("r2", r2)])
        wba.assemble(strides)

        assert wba.wb_list_[0]["end"] == 4

        # combined rules
        wba = WBAssembly([("r1", r1), ("r2", r2)])
        wba.assemble(strides)

        assert wba.wb_list_[0]["end"] == 4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant