Skip to content

Commit

Permalink
Set default parameters
Browse files Browse the repository at this point in the history
summary_extractor, summary_extractor_args
  • Loading branch information
m-martin-j committed Nov 23, 2022
1 parent c18b3d1 commit a7ce08a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def alert_cbck(alert_code, alert_msg):
alert_msg = 'no msg'
print(f'{alert_msg} (code {alert_code})')
c = CDCStream(
alert_callback=alert_cbck,
summary_extractor=dilca_workflow,
summary_extractor_args={'nominal_cols': 'all'},
alert_callback=alert_cbck,
factor_warn=2.0,
factor_change=3.0,
factor_std_extr_forg=0,
Expand Down
18 changes: 11 additions & 7 deletions cdcstream/cdcstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import numpy as np
import pandas as pd

from cdcstream.dilca_wrapper import dilca_workflow


class DriftDetector:
ALERT_NONE = 0
Expand All @@ -42,8 +44,9 @@ class UnsupervisedDriftDetector(DriftDetector):
class CDCStream(UnsupervisedDriftDetector):

def __init__(self,
summary_extractor: Callable, summary_extractor_args: dict,
alert_callback: Callable,
summary_extractor: Callable=dilca_workflow,
summary_extractor_args: dict={'nominal_cols': 'all'},
factor_warn: float=2.0,
factor_change: float=3.0,
factor_std_extr_forg: Union[float, int] =0,
Expand All @@ -63,13 +66,15 @@ def __init__(self,
}
Args:
summary_extractor (Callable): Function for extracting a summary value from a batch of
data. Must accept as first parameter: data; as named parameter: supervised.
summary_extractor_args (dict): Other named parameters (apart from first parameter data
and named parameter supervised) to pass to summary_extractor.
alert_callback (Callable): Function called after each batch with an alert code of
ALERT_NONE, ALERT_WARN, or ALERT_CHANGE. Must accept as arguments: alert_code: int,
alert_msg: str.
summary_extractor (Callable, optional): Function for extracting a summary value from a
batch of data. Must accept as first parameter: data; as named parameter:
supervised. Defaults to cdcstream.dilca_wrapper.dilca_workflow.
summary_extractor_args (dict, optional): Other named parameters (apart from first
parameter data and named parameter supervised) to pass to summary_extractor.
Defaults to {'nominal_cols': 'all'}.
factor_warn (float, optional): Parameter for Chebyshev's Inequality for issuing a drift
warning. Must be smaller than or equal to factor_change. Defaults to 2.0.
factor_change (float, optional): Parameter for Chebyshev's Inequality for signaling a
Expand Down Expand Up @@ -300,7 +305,6 @@ def log(self) -> pd.DataFrame:


if __name__ == '__main__':
from cdcstream.dilca_wrapper import dilca_workflow
from cdcstream import tools


Expand All @@ -314,9 +318,9 @@ def alert_cbck(alert_code, alert_msg):
alert_msg = 'no msg'
print(f'{alert_msg} (code {alert_code})')
c = CDCStream(
alert_callback=alert_cbck,
summary_extractor=dilca_workflow,
summary_extractor_args={'nominal_cols': 'all'},
alert_callback=alert_cbck,
factor_warn=2.0,
factor_change=3.0,
factor_std_extr_forg=0,
Expand Down

0 comments on commit a7ce08a

Please sign in to comment.