Skip to content

Commit

Permalink
Filled out preserve method, including error handling. Refactor on spe…
Browse files Browse the repository at this point in the history
…c test.
  • Loading branch information
interpret-ml committed May 31, 2019
1 parent d16f7af commit 834ee40
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 52 deletions.
1 change: 0 additions & 1 deletion src/python/interpret/test/test_develop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@ def test_print_debug_info():
# Very light check, just testing if the function runs.
print_debug_info()
assert 1 == 1

40 changes: 4 additions & 36 deletions src/python/interpret/test/test_explainers.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,15 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

from ..data import ClassHistogram
from ..perf import ROC, RegressionPerf

from ..blackbox import LimeTabular
from ..blackbox import ShapKernel
from ..blackbox import MorrisSensitivity
from ..blackbox import PartialDependence

from ..glassbox import LogisticRegression, LinearRegression
from ..glassbox import ClassificationTree, RegressionTree
from ..glassbox import DecisionListClassifier
from ..glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor

from .utils import synthetic_classification
from .utils import synthetic_classification, get_all_explainers
from .utils import assert_valid_explanation, assert_valid_model_explainer

from ..glassbox import LogisticRegression

def test_spec_synthetic():
data_explainer_classes = [ClassHistogram]
perf_explainer_classes = [ROC, RegressionPerf]
model_explainer_classes = [
ClassificationTree,
DecisionListClassifier,
LogisticRegression,
ExplainableBoostingClassifier,
RegressionTree,
LinearRegression,
ExplainableBoostingRegressor,
]
blackbox_explainer_classes = [
LimeTabular,
ShapKernel,
MorrisSensitivity,
PartialDependence,
]
all_explainers = []
all_explainers.extend(model_explainer_classes)
all_explainers.extend(blackbox_explainer_classes)
all_explainers.extend(data_explainer_classes)
all_explainers.extend(perf_explainer_classes)

def test_spec_synthetic():
all_explainers = get_all_explainers()
data = synthetic_classification()
blackbox = LogisticRegression()
blackbox.fit(data["train"]["X"], data["train"]["y"])
Expand Down
1 change: 1 addition & 0 deletions src/python/interpret/test/test_interactive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license
# TODO: Testing for show/snap functions.

from ..visual.interactive import set_show_addr, get_show_addr, shutdown_show_server

Expand Down
40 changes: 40 additions & 0 deletions src/python/interpret/test/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
# Copyright (c) 2019 Microsoft Corporation
# Distributed under the MIT software license

from ..data import ClassHistogram
from ..perf import ROC, RegressionPerf

from ..blackbox import LimeTabular
from ..blackbox import ShapKernel
from ..blackbox import MorrisSensitivity
from ..blackbox import PartialDependence

from ..glassbox import LogisticRegression, LinearRegression
from ..glassbox import ClassificationTree, RegressionTree
from ..glassbox import DecisionListClassifier
from ..glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
Expand All @@ -10,6 +23,33 @@
from sklearn.base import is_classifier


def get_all_explainers():
data_explainer_classes = [ClassHistogram]
perf_explainer_classes = [ROC, RegressionPerf]
model_explainer_classes = [
ClassificationTree,
DecisionListClassifier,
LogisticRegression,
ExplainableBoostingClassifier,
RegressionTree,
LinearRegression,
ExplainableBoostingRegressor,
]
blackbox_explainer_classes = [
LimeTabular,
ShapKernel,
MorrisSensitivity,
PartialDependence,
]
all_explainers = []
all_explainers.extend(model_explainer_classes)
all_explainers.extend(blackbox_explainer_classes)
all_explainers.extend(data_explainer_classes)
all_explainers.extend(perf_explainer_classes)

return all_explainers


def synthetic_regression():
dataset = _synthetic("regression")
return dataset
Expand Down
77 changes: 62 additions & 15 deletions src/python/interpret/visual/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def preserve(explanation, selector_key=None, file_name=None, **kwargs):
If file_name is not None the following occurs:
- For Plotly figures, saves to HTML using `plot`.
- For dataframes, saves to CSV using `to_csv`.
- For dataframes, saves to HTML using `to_html`.
- For strings (html), saves to HTML.
- For Dash components, fails with exception. This is currently not supported.
Expand All @@ -147,17 +147,58 @@ def preserve(explanation, selector_key=None, file_name=None, **kwargs):
None.
"""

try:
# Get explanation key
if selector_key is None:
key = None
else:
series = explanation.selector[explanation.selector.columns[0]]
key = series[series == selector_key].index[0]

# Get visual object
visual = explanation.visualize(key=key)

# Output to front-end/file
_preserve_output(
explanation.name,
visual,
selector_key=selector_key,
file_name=file_name,
**kwargs
)
return None
except Exception as e:
log.error(e, exc_info=True)
raise e


def _preserve_output(
explanation_name, visual, selector_key=None, file_name=None, **kwargs
):
from plotly.offline import iplot, plot, init_notebook_mode
from IPython.display import display, HTML
from IPython.display import display, display_html
from base64 import b64encode

init_notebook_mode(connected=True)

if selector_key is None:
key = None
else:
series = explanation.selector[explanation.selector.columns[0]]
key = series[series == selector_key].index[0]
def render_html(html_string):
base64_html = b64encode(html_string.encode("utf-8")).decode("ascii")
final_html = """<iframe src="data:text/html;base64,{data}" width="100%" height=400 frameBorder="0"></iframe>""".format(
data=base64_html
)
display_html(final_html, raw=True)

if visual is None:
msg = "No visualization for explanation [{0}] with selector_key [{1}]".format(
explanation_name, selector_key
)
log.error(msg)
if file_name is None:
render_html(msg)
else:
pass
return False

visual = explanation.visualize(key=key)
if isinstance(visual, go.Figure):
if file_name is None:
iplot(visual, **kwargs)
Expand All @@ -167,18 +208,24 @@ def preserve(explanation, selector_key=None, file_name=None, **kwargs):
if file_name is None:
display(visual, **kwargs)
else:
visual.to_csv(file_name, **kwargs)
visual.to_html(file_name, **kwargs)
elif isinstance(visual, str):
if file_name is None:
with(file_name, "w") as f:
f.write(visual)
render_html(visual)
else:
HTML(visual, **kwargs)
with open(file_name, "w") as f:
f.write(visual)
elif isinstance(visual, dash_base.Component):
msg = "Preserving dash components is currently not supported."
raise Exception(msg)
if file_name is None:
render_html(msg)
log.error(msg)
return False
else:
msg = "Visualization cannot be preserved for type: {0}.".format(type(visual))
raise Exception(msg)
if file_name is None:
render_html(msg)
log.error(msg)
return False

return None
return True

0 comments on commit 834ee40

Please sign in to comment.