From a5f11000c9cb0a39ee2651a2a260a92bffbbb1db Mon Sep 17 00:00:00 2001 From: Fabian Degen Date: Thu, 24 Oct 2024 15:57:43 +0200 Subject: [PATCH] add mono-classification with Shap Signed-off-by: Fabian Degen --- .../interpret/utils/_unify_predict.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/python/interpret-core/interpret/utils/_unify_predict.py b/python/interpret-core/interpret/utils/_unify_predict.py index 7cffc0686..0fcc6e784 100644 --- a/python/interpret-core/interpret/utils/_unify_predict.py +++ b/python/interpret-core/interpret/utils/_unify_predict.py @@ -3,6 +3,7 @@ import logging +import numpy as np from sklearn.base import is_classifier, is_regressor from ._clean_simple import clean_dimensions @@ -47,6 +48,30 @@ def determine_classes(model, data, n_samples): n_classes = 1 else: n_classes = preds.shape[1] + + if len(classes) == 1: + # for single-class problems, treat it as binary classification + # where the second class probability is always 0 + n_classes = 2 + orig_model = model + original_class = classes[0] + + print( + f"Warning: Model was trained on single-class data. Model will always predict class {original_class}." + ) + + def mono_classification_model(data): + preds = orig_model(data) + if preds.ndim == 1: + preds = preds.reshape(-1, 1) + # add zero probabilities for the synthetic class + return np.hstack([preds, np.zeros_like(preds)]) + + model = mono_classification_model + # keep original class and add any different value as synthetic class + synthetic_class = "other" if original_class != "other" else "synthetic" + classes = np.array([original_class, synthetic_class]) + if n_classes != len(classes): msg = "class number mismatch" _log.error(msg)