Skip to content

Commit

Permalink
add mono-classification with Shap (#582)
Browse files Browse the repository at this point in the history
Signed-off-by: Fabian Degen <[email protected]>
Co-authored-by: Fabian Degen <[email protected]>
  • Loading branch information
degenfabian and Fabian Degen authored Oct 24, 2024
1 parent 0b5034f commit 90734ec
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/interpret-core/interpret/utils/_unify_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging

import numpy as np
from sklearn.base import is_classifier, is_regressor

from ._clean_simple import clean_dimensions
Expand Down Expand Up @@ -47,6 +48,30 @@ def determine_classes(model, data, n_samples):
n_classes = 1
else:
n_classes = preds.shape[1]

if len(classes) == 1:
# for single-class problems, treat it as binary classification
# where the second class probability is always 0
n_classes = 2
orig_model = model
original_class = classes[0]

print(
f"Warning: Model was trained on single-class data. Model will always predict class {original_class}."
)

def mono_classification_model(data):
preds = orig_model(data)
if preds.ndim == 1:
preds = preds.reshape(-1, 1)
# add zero probabilities for the synthetic class
return np.hstack([preds, np.zeros_like(preds)])

model = mono_classification_model
# keep original class and add any different value as synthetic class
synthetic_class = "other" if original_class != "other" else "synthetic"
classes = np.array([original_class, synthetic_class])

if n_classes != len(classes):
msg = "class number mismatch"
_log.error(msg)
Expand Down

0 comments on commit 90734ec

Please sign in to comment.