diff --git a/README.md b/README.md index 0c090a9..d5ba68e 100644 --- a/README.md +++ b/README.md @@ -46,9 +46,9 @@ def alert_cbck(alert_code, alert_msg): alert_msg = 'no msg' print(f'{alert_msg} (code {alert_code})') c = CDCStream( + alert_callback=alert_cbck, summary_extractor=dilca_workflow, summary_extractor_args={'nominal_cols': 'all'}, - alert_callback=alert_cbck, factor_warn=2.0, factor_change=3.0, factor_std_extr_forg=0, diff --git a/cdcstream/cdcstream.py b/cdcstream/cdcstream.py index 5f2218f..60d5141 100644 --- a/cdcstream/cdcstream.py +++ b/cdcstream/cdcstream.py @@ -24,6 +24,8 @@ import numpy as np import pandas as pd +from cdcstream.dilca_wrapper import dilca_workflow + class DriftDetector: ALERT_NONE = 0 @@ -42,8 +44,9 @@ class UnsupervisedDriftDetector(DriftDetector): class CDCStream(UnsupervisedDriftDetector): def __init__(self, - summary_extractor: Callable, summary_extractor_args: dict, alert_callback: Callable, + summary_extractor: Callable=dilca_workflow, + summary_extractor_args: dict={'nominal_cols': 'all'}, factor_warn: float=2.0, factor_change: float=3.0, factor_std_extr_forg: Union[float, int] =0, @@ -63,13 +66,15 @@ def __init__(self, } Args: - summary_extractor (Callable): Function for extracting a summary value from a batch of - data. Must accept as first parameter: data; as named parameter: supervised. - summary_extractor_args (dict): Other named parameters (apart from first parameter data - and named parameter supervised) to pass to summary_extractor. alert_callback (Callable): Function being called after each batch with an alert code of ALERT_NONE, ALERT_WARN, ALERT_CHANGE. Must accept as arguments: alert_code: int, alert_msg: str. + summary_extractor (Callable, optional): Function for extracting a summary value from a + batch of data. Must accept as first parameter: data; as named parameter: + supervised. Defaults to cdcstream.dilca_wrapper.dilca_workflow. + summary_extractor_args (dict, optional): Other named parameters (apart from first + parameter data and named parameter supervised) to pass to summary_extractor. + Defaults to {'nominal_cols': 'all'}. factor_warn (float, optional): Parameter for Chebychev's Inequality for issuing a drift warning. Must be smaller than or equal to factor_change. Defaults to 2.0. factor_change (float, optional): Parameter for Chebychev's Inequality for signaling a @@ -300,7 +305,6 @@ def log(self) -> pd.DataFrame: if __name__ == '__main__': - from cdcstream.dilca_wrapper import dilca_workflow from cdcstream import tools @@ -314,9 +318,9 @@ def alert_cbck(alert_code, alert_msg): alert_msg = 'no msg' print(f'{alert_msg} (code {alert_code})') c = CDCStream( + alert_callback=alert_cbck, summary_extractor=dilca_workflow, summary_extractor_args={'nominal_cols': 'all'}, - alert_callback=alert_cbck, factor_warn=2.0, factor_change=3.0, factor_std_extr_forg=0,