From d0b64f22a3418738530ef83b68889492d3e34666 Mon Sep 17 00:00:00 2001 From: Martin Trat Date: Wed, 23 Nov 2022 17:48:31 +0100 Subject: [PATCH] Change drift-reporting interface allow passing no alert_callback, pull resetting of current_warn and current_change out of _cleanup_current_cycle and execute separately in _reset method, call _reset method at the beginning of a new cycle instead of the end - this way user can check for drift herself after each batch being fed --- README.md | 1 + cdcstream/cdcstream.py | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d5ba68e..72bebc0 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ def alert_cbck(alert_code, alert_msg): if not alert_msg: alert_msg = 'no msg' print(f'{alert_msg} (code {alert_code})') + c = CDCStream( alert_callback=alert_cbck, summary_extractor=dilca_workflow, diff --git a/cdcstream/cdcstream.py b/cdcstream/cdcstream.py index 60d5141..c8a2362 100644 --- a/cdcstream/cdcstream.py +++ b/cdcstream/cdcstream.py @@ -44,7 +44,7 @@ class UnsupervisedDriftDetector(DriftDetector): class CDCStream(UnsupervisedDriftDetector): def __init__(self, - alert_callback: Callable, + alert_callback: Callable=None, summary_extractor: Callable=dilca_workflow, summary_extractor_args: dict={'nominal_cols': 'all'}, factor_warn: float=2.0, @@ -66,9 +66,9 @@ def __init__(self, } Args: - alert_callback (Callable): Function being called after each batch with an alert code of - ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments: alert_code: int, - alert_msg: str. + alert_callback (Callable, optional): Function being called after each batch with an + alert code of ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments: + alert_code: int, alert_msg: str. Defaults to None. summary_extractor (Callable, optional): Function for extracting a summary value from a batch of data. Must accept as first parameter: data; as named parameter: supervised. Defaults to cdcstream.dilca_wrapper.dilca_workflow. @@ -134,6 +134,7 @@ def feed_new_batch(self, batch: pd.DataFrame) -> None: def _cycle_routine(self) -> None: """Main cycle of drift detector. It is being called with each new batch being fed. """ + self._reset() self._extract_current_batch_summary_statistic() self._std_extrema_forgetting() # BEFORE history statistics (especially stds) are being computed self._compute_history_statistics() @@ -152,8 +153,6 @@ def _cleanup_current_cycle(self) -> None: if self.current_change: # React to CHANGE self.reset_history() self.start_cooldown() - self.current_warn = False - self.current_change = False if self.batch_current_summary_statistic is not None: self.history.append(self.batch_current_summary_statistic) @@ -224,6 +223,10 @@ def start_cooldown(self) -> None: """ self._cur_cooldown_cycles = self.cooldown_cycles + def _reset(self): + self.current_warn = False + self.current_change = False + def reset_history(self) -> None: self.history = [] @@ -262,7 +265,10 @@ def _alert(self) -> None: else: alert_code = self.ALERT_NONE alert_msg = self.ALERT_NONE_MSG - self._alert_callback(alert_code, alert_msg) + try: + self._alert_callback(alert_code, alert_msg) + except TypeError: + pass # if user provided no alert_callback def _update_log(self) -> None: log_el = () @@ -317,6 +323,7 @@ def alert_cbck(alert_code, alert_msg): if not alert_msg: alert_msg = 'no msg' print(f'{alert_msg} (code {alert_code})') + c = CDCStream( alert_callback=alert_cbck, summary_extractor=dilca_workflow,