Skip to content

Commit

Permalink
More tests around environment and EBM.
Browse files Browse the repository at this point in the history
  • Loading branch information
interpret-ml committed Sep 18, 2019
1 parent 4e47d45 commit 67da4f4
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 29 deletions.
65 changes: 38 additions & 27 deletions python/interpret-core/interpret/glassbox/ebm/test/test_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,6 @@ def test_ebm_synthetic_multiclass_pairwise():
clf.fit(X, y)


@pytest.mark.slow
def test_ebm_multiclass():
data = iris_classification()
X_train = data["train"]["X"]
y_train = data["train"]["y"]

X_test = data["test"]["X"]
y_test = data["test"]["y"]

clf = ExplainableBoostingClassifier()
clf.fit(X_train, y_train)

assert accuracy_score(y_test, clf.predict(X_test)) > 0.9


def test_ebm_synthetic_pairwise():
a = np.random.randint(low=0, high=50, size=10000)
b = np.random.randint(low=0, high=20, size=10000)
Expand Down Expand Up @@ -152,11 +137,27 @@ def test_ebm_synthetic_classfication():
valid_ebm(clf)


def _smoke_test_explanations(global_exp, local_exp, port):
from .... import preserve, show, shutdown_show_server, set_show_addr

set_show_addr(("127.0.0.1", port))

# Smoke test: should run without crashing.
preserve(global_exp)
preserve(local_exp)
show(global_exp)
show(local_exp)

# Check all features for global (including interactions).
for selector_key in global_exp.selector[global_exp.selector.columns[0]]:
preserve(global_exp, selector_key)

shutdown_show_server()


@pytest.mark.visual
@pytest.mark.slow
def test_ebm_adult():
from .... import preserve, show, shutdown_show_server, set_show_addr

data = adult_classification()
X = data["full"]["X"]
y = data["full"]["y"]
Expand All @@ -175,18 +176,28 @@ def test_ebm_adult():

valid_ebm(clf)

set_show_addr(("127.0.0.1", 6000))
global_exp = clf.explain_global()
local_exp = clf.explain_local(X[:5, :], y[:5])

# Smoke test: should run without crashing.
preserve(global_exp)
preserve(local_exp)
show(global_exp)
show(local_exp)
_smoke_test_explanations(global_exp, local_exp, 6000)

# Check all features for global (including interactions).
for selector_key in global_exp.selector[global_exp.selector.columns[0]]:
preserve(global_exp, selector_key)

shutdown_show_server()
@pytest.mark.visual
@pytest.mark.slow
def test_ebm_iris():
data = iris_classification()
X_train = data["train"]["X"]
y_train = data["train"]["y"]

X_test = data["test"]["X"]
y_test = data["test"]["y"]

clf = ExplainableBoostingClassifier()
clf.fit(X_train, y_train)

assert accuracy_score(y_test, clf.predict(X_test)) > 0.9

global_exp = clf.explain_global()
local_exp = clf.explain_local(X_test, y_test)

_smoke_test_explanations(global_exp, local_exp, 6001)
4 changes: 2 additions & 2 deletions python/interpret-core/interpret/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _detect_ipython():
from IPython import get_ipython

return get_ipython() is not None
except NameError:
except NameError: # pragma: no cover
return False


Expand All @@ -48,7 +48,7 @@ def _detect_ipython_zmq():
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
except NameError: # pragma: no cover
return False # Probably standard Python interpreter


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@


def test_environment_detector():
# Default
detector = EnvironmentDetector()
envs = detector.detect()
assert len(envs) == 0

# Check if assertion succeeds
detector.checks["always_true"] = lambda: True
envs = detector.detect()
assert len(envs) == 1 and envs[0] == "always_true"

0 comments on commit 67da4f4

Please sign in to comment.