diff --git a/python/interpret-core/interpret/glassbox/ebm/test/test_ebm.py b/python/interpret-core/interpret/glassbox/ebm/test/test_ebm.py index 0460cf7be..e57ce7ea5 100644 --- a/python/interpret-core/interpret/glassbox/ebm/test/test_ebm.py +++ b/python/interpret-core/interpret/glassbox/ebm/test/test_ebm.py @@ -58,21 +58,6 @@ def test_ebm_synthetic_multiclass_pairwise(): clf.fit(X, y) -@pytest.mark.slow -def test_ebm_multiclass(): - data = iris_classification() - X_train = data["train"]["X"] - y_train = data["train"]["y"] - - X_test = data["test"]["X"] - y_test = data["test"]["y"] - - clf = ExplainableBoostingClassifier() - clf.fit(X_train, y_train) - - assert accuracy_score(y_test, clf.predict(X_test)) > 0.9 - - def test_ebm_synthetic_pairwise(): a = np.random.randint(low=0, high=50, size=10000) b = np.random.randint(low=0, high=20, size=10000) @@ -152,11 +137,27 @@ def test_ebm_synthetic_classfication(): valid_ebm(clf) +def _smoke_test_explanations(global_exp, local_exp, port): + from .... import preserve, show, shutdown_show_server, set_show_addr + + set_show_addr(("127.0.0.1", port)) + + # Smoke test: should run without crashing. + preserve(global_exp) + preserve(local_exp) + show(global_exp) + show(local_exp) + + # Check all features for global (including interactions). + for selector_key in global_exp.selector[global_exp.selector.columns[0]]: + preserve(global_exp, selector_key) + + shutdown_show_server() + + @pytest.mark.visual @pytest.mark.slow def test_ebm_adult(): - from .... import preserve, show, shutdown_show_server, set_show_addr - data = adult_classification() X = data["full"]["X"] y = data["full"]["y"] @@ -175,18 +176,28 @@ def test_ebm_adult(): valid_ebm(clf) - set_show_addr(("127.0.0.1", 6000)) global_exp = clf.explain_global() local_exp = clf.explain_local(X[:5, :], y[:5]) - # Smoke test: should run without crashing. - preserve(global_exp) - preserve(local_exp) - show(global_exp) - show(local_exp) + _smoke_test_explanations(global_exp, local_exp, 6000) - # Check all features for global (including interactions). - for selector_key in global_exp.selector[global_exp.selector.columns[0]]: - preserve(global_exp, selector_key) - shutdown_show_server() +@pytest.mark.visual +@pytest.mark.slow +def test_ebm_iris(): + data = iris_classification() + X_train = data["train"]["X"] + y_train = data["train"]["y"] + + X_test = data["test"]["X"] + y_test = data["test"]["y"] + + clf = ExplainableBoostingClassifier() + clf.fit(X_train, y_train) + + assert accuracy_score(y_test, clf.predict(X_test)) > 0.9 + + global_exp = clf.explain_global() + local_exp = clf.explain_local(X_test, y_test) + + _smoke_test_explanations(global_exp, local_exp, 6001) diff --git a/python/interpret-core/interpret/utils/environment.py b/python/interpret-core/interpret/utils/environment.py index c368b50d1..6d2e0bf31 100644 --- a/python/interpret-core/interpret/utils/environment.py +++ b/python/interpret-core/interpret/utils/environment.py @@ -25,7 +25,7 @@ def _detect_ipython(): from IPython import get_ipython return get_ipython() is not None - except NameError: + except NameError: # pragma: no cover return False @@ -48,7 +48,7 @@ def _detect_ipython_zmq(): return False # Terminal running IPython else: return False # Other type (?) - except NameError: + except NameError: # pragma: no cover return False # Probably standard Python interpreter diff --git a/python/interpret-core/interpret/utils/test/test_environment.py b/python/interpret-core/interpret/utils/test/test_environment.py index 1b238ba2d..aad39190f 100644 --- a/python/interpret-core/interpret/utils/test/test_environment.py +++ b/python/interpret-core/interpret/utils/test/test_environment.py @@ -5,6 +5,12 @@ def test_environment_detector(): + # Default detector = EnvironmentDetector() envs = detector.detect() assert len(envs) == 0 + + # Check if assertion succeeds + detector.checks["always_true"] = lambda: True + envs = detector.detect() + assert len(envs) == 1 and envs[0] == "always_true"