diff --git a/python/interpret-core/interpret/glassbox/_decisiontree.py b/python/interpret-core/interpret/glassbox/_decisiontree.py index dd82b2ca1..a3d491d58 100644 --- a/python/interpret-core/interpret/glassbox/_decisiontree.py +++ b/python/interpret-core/interpret/glassbox/_decisiontree.py @@ -260,12 +260,14 @@ def _model(self): # This method should be overriden return None - def fit(self, X, y): + def fit(self, X, y, sample_weight=None, check_input=True): """Fits model to provided instances. Args: X: Numpy array for training instances. y: Numpy array as training labels. + sample_weight (optional[np.ndarray]): (n_samples,) Sample weights. If None (default), then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are ignored while searching for a split in each node. + check_input (bool): default=True. Allow to bypass several input checking. Don't use this parameter unless you know what you're doing. Returns: Itself. @@ -289,7 +291,7 @@ def fit(self, X, y): ) model = self._model() - model.fit(X, y) + model.fit(X, y, sample_weight=sample_weight, check_input=check_input) unique_val_counts = np.zeros(len(self.feature_names_in_), dtype=np.int64) for col_idx in range(len(self.feature_names_in_)): @@ -571,18 +573,25 @@ def __init__(self, feature_names=None, feature_types=None, max_depth=3, **kwargs def _model(self): return self.sk_model_ - def fit(self, X, y): + def fit(self, X, y, sample_weight=None, check_input=True): """Fits model to provided instances. Args: X: Numpy array for training instances. y: Numpy array as training labels. + sample_weight (optional[np.ndarray]): (n_samples,) Sample weights. If None (default), then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are ignored while searching for a split in each node. + check_input (bool): default=True. Allow to bypass several input checking. Don't use this parameter unless you know what you're doing. Returns: Itself. """ self.sk_model_ = SKRT(max_depth=self.max_depth, **self.kwargs) - return super().fit(X, y) + return super().fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + ) class ClassificationTree(BaseShallowDecisionTree, ClassifierMixin, ExplainerMixin): @@ -607,18 +616,25 @@ def __init__(self, feature_names=None, feature_types=None, max_depth=3, **kwargs def _model(self): return self.sk_model_ - def fit(self, X, y): + def fit(self, X, y, sample_weight=None, check_input=True): """Fits model to provided instances. Args: X: Numpy array for training instances. y: Numpy array as training labels. + sample_weight (optional[np.ndarray]): (n_samples,) Sample weights. If None (default), then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are ignored while searching for a split in each node. + check_input (bool): default=True. Allow to bypass several input checking. Don't use this parameter unless you know what you're doing. Returns: Itself. """ self.sk_model_ = SKDT(max_depth=self.max_depth, **self.kwargs) - return super().fit(X, y) + return super().fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + ) def predict_proba(self, X): """Probability estimates on provided instances.