Skip to content

Commit

Permalink
Change drift-reporting interface
Browse files Browse the repository at this point in the history
allow passing no alert_callback,
pull resetting of current_warn and current_change out of _cleanup_current_cycle and execute separately in _reset method,
call _reset method at the beginning of a new cycle instead of the end - this way user can check for drift herself after each batch being fed
  • Loading branch information
m-martin-j committed Nov 23, 2022
1 parent a7ce08a commit d0b64f2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def alert_cbck(alert_code, alert_msg):
if not alert_msg:
alert_msg = 'no msg'
print(f'{alert_msg} (code {alert_code})')
c = CDCStream(
alert_callback=alert_cbck,
summary_extractor=dilca_workflow,
Expand Down
21 changes: 14 additions & 7 deletions cdcstream/cdcstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class UnsupervisedDriftDetector(DriftDetector):
class CDCStream(UnsupervisedDriftDetector):

def __init__(self,
alert_callback: Callable,
alert_callback: Callable=None,
summary_extractor: Callable=dilca_workflow,
summary_extractor_args: dict={'nominal_cols': 'all'},
factor_warn: float=2.0,
Expand All @@ -66,9 +66,9 @@ def __init__(self,
}
Args:
alert_callback (Callable): Function being called after each batch with an alert code of
ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments: alert_code: int,
alert_msg: str.
alert_callback (Callable, optional): Function being called after each batch with an
alert code of ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments:
alert_code: int, alert_msg: str. Defaults to None.
summary_extractor (Callable, optional): Function for extracting a summary value from a
batch of data. Must accept as first parameter: data; as named parameter:
supervised. Defaults to cdcstream.dilca_wrapper.dilca_workflow.
Expand Down Expand Up @@ -134,6 +134,7 @@ def feed_new_batch(self, batch: pd.DataFrame) -> None:
def _cycle_routine(self) -> None:
"""Main cycle of drift detector. It is being called with each new batch being fed.
"""
self._reset()
self._extract_current_batch_summary_statistic()
self._std_extrema_forgetting() # BEFORE history statistics (especially stds) are being computed
self._compute_history_statistics()
Expand All @@ -152,8 +153,6 @@ def _cleanup_current_cycle(self) -> None:
if self.current_change: # React to CHANGE
self.reset_history()
self.start_cooldown()
self.current_warn = False
self.current_change = False

if self.batch_current_summary_statistic is not None:
self.history.append(self.batch_current_summary_statistic)
Expand Down Expand Up @@ -224,6 +223,10 @@ def start_cooldown(self) -> None:
"""
self._cur_cooldown_cycles = self.cooldown_cycles

def _reset(self):
self.current_warn = False
self.current_change = False

def reset_history(self) -> None:
self.history = []

Expand Down Expand Up @@ -262,7 +265,10 @@ def _alert(self) -> None:
else:
alert_code = self.ALERT_NONE
alert_msg = self.ALERT_NONE_MSG
self._alert_callback(alert_code, alert_msg)
try:
self._alert_callback(alert_code, alert_msg)
except TypeError:
pass # if user provided no alert_callback

def _update_log(self) -> None:
log_el = ()
Expand Down Expand Up @@ -317,6 +323,7 @@ def alert_cbck(alert_code, alert_msg):
if not alert_msg:
alert_msg = 'no msg'
print(f'{alert_msg} (code {alert_code})')

c = CDCStream(
alert_callback=alert_cbck,
summary_extractor=dilca_workflow,
Expand Down

0 comments on commit d0b64f2

Please sign in to comment.