Skip to content

Commit

Permalink
feat(automl): add support for feature importance (#9816)
Browse files Browse the repository at this point in the history
Previously feature importance is added with a params parameter, the
user has to set params = {"feature_importance": "true"}. This PR
simplifies the logic, the user just have to pass feature_importance =
True.
  • Loading branch information
helinwang authored and busunkim96 committed Nov 16, 2019
1 parent 972e5b4 commit 8d745d0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
12 changes: 8 additions & 4 deletions automl/google/cloud/automl_v1beta1/tables/tables_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,7 +2596,7 @@ def predict(
model=None,
model_name=None,
model_display_name=None,
params=None,
feature_importance=False,
project=None,
region=None,
**kwargs
Expand Down Expand Up @@ -2643,9 +2643,9 @@ def predict(
The `model` instance you want to predict with . This must be
supplied if `model_display_name` or `model_name` are not
supplied.
params (dict[str, str]):
`feature_importance` can be set as True to enable local
explainability. The default is false.
feature_importance (bool):
True if enable feature importance explainability. The default is
False.
Returns:
A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
Expand Down Expand Up @@ -2687,6 +2687,10 @@ def predict(

request = {"row": {"values": values}}

params = None
if feature_importance:
params = {"feature_importance": "true"}

return self.prediction_client.predict(model.name, request, params, **kwargs)

def batch_predict(
Expand Down
19 changes: 19 additions & 0 deletions automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,25 @@ def test_predict_from_dict(self):
None,
)

def test_predict_from_dict_with_feature_importance(self):
data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
column_spec_a = mock.Mock(display_name="a", data_type=data_type)
column_spec_b = mock.Mock(display_name="b", data_type=data_type)
model_metadata = mock.Mock(
input_feature_column_specs=[column_spec_a, column_spec_b]
)
model = mock.Mock()
model.configure_mock(tables_model_metadata=model_metadata, name="my_model")
client = self.tables_client({"get_model.return_value": model}, {})
client.predict(
{"a": "1", "b": "2"}, model_name="my_model", feature_importance=True
)
client.prediction_client.predict.assert_called_with(
"my_model",
{"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}},
{"feature_importance": "true"},
)

def test_predict_from_dict_missing(self):
data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)
column_spec_a = mock.Mock(display_name="a", data_type=data_type)
Expand Down

0 comments on commit 8d745d0

Please sign in to comment.