diff --git a/README.md b/README.md index d5ba68e..72bebc0 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ def alert_cbck(alert_code, alert_msg): if not alert_msg: alert_msg = 'no msg' print(f'{alert_msg} (code {alert_code})') + c = CDCStream( alert_callback=alert_cbck, summary_extractor=dilca_workflow, diff --git a/cdcstream/cdcstream.py b/cdcstream/cdcstream.py index 60d5141..c8a2362 100644 --- a/cdcstream/cdcstream.py +++ b/cdcstream/cdcstream.py @@ -44,7 +44,7 @@ class UnsupervisedDriftDetector(DriftDetector): class CDCStream(UnsupervisedDriftDetector): def __init__(self, - alert_callback: Callable, + alert_callback: Callable=None, summary_extractor: Callable=dilca_workflow, summary_extractor_args: dict={'nominal_cols': 'all'}, factor_warn: float=2.0, @@ -66,9 +66,9 @@ def __init__(self, } Args: - alert_callback (Callable): Function being called after each batch with an alert code of - ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments: alert_code: int, - alert_msg: str. + alert_callback (Callable, optional): Function being called after each batch with an + alert code of ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments: + alert_code: int, alert_msg: str. Defaults to None. summary_extractor (Callable, optional): Function for extracting a summary value from a batch of data. Must accept as first parameter: data; as named parameter: supervised. Defaults to cdcstream.dilca_wrapper.dilca_workflow. @@ -134,6 +134,7 @@ def feed_new_batch(self, batch: pd.DataFrame) -> None: def _cycle_routine(self) -> None: """Main cycle of drift detector. It is being called with each new batch being fed. """ + self._reset() self._extract_current_batch_summary_statistic() self._std_extrema_forgetting() # BEFORE history statistics (especially stds) are being computed self._compute_history_statistics() @@ -152,8 +153,6 @@ def _cleanup_current_cycle(self) -> None: if self.current_change: # React to CHANGE self.reset_history() self.start_cooldown() - self.current_warn = False - self.current_change = False if self.batch_current_summary_statistic is not None: self.history.append(self.batch_current_summary_statistic) @@ -224,6 +223,10 @@ def start_cooldown(self) -> None: """ self._cur_cooldown_cycles = self.cooldown_cycles + def _reset(self): + self.current_warn = False + self.current_change = False + def reset_history(self) -> None: self.history = [] @@ -262,7 +265,10 @@ def _alert(self) -> None: else: alert_code = self.ALERT_NONE alert_msg = self.ALERT_NONE_MSG - self._alert_callback(alert_code, alert_msg) + try: + self._alert_callback(alert_code, alert_msg) + except TypeError: + pass # if user provided no alert_callback def _update_log(self) -> None: log_el = () @@ -317,6 +323,7 @@ def alert_cbck(alert_code, alert_msg): if not alert_msg: alert_msg = 'no msg' print(f'{alert_msg} (code {alert_code})') + c = CDCStream( alert_callback=alert_cbck, summary_extractor=dilca_workflow,