Skip to content

Commit

Permalink
make regression metrics 'multioutput' behavior consistent with scikit…
Browse files Browse the repository at this point in the history
…-learn (fixes #818) (#820)

* make regression metrics 'multioutput' behavior consistent with scikit-learn

* add test on error messages

* linting
  • Loading branch information
jameslamb authored Apr 10, 2021
1 parent 27d8d37 commit db2e7d5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 12 deletions.
24 changes: 12 additions & 12 deletions dask_ml/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _check_sample_weight(sample_weight: Optional[ArrayLike]):
def _check_reg_targets(
y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str]
):
if multioutput != "uniform_average":
if multioutput is not None and multioutput != "uniform_average":
raise NotImplementedError("'multioutput' must be 'uniform_average'")

if y_true.ndim == 1:
Expand All @@ -40,12 +40,12 @@ def mean_squared_error(
_check_sample_weight(sample_weight)
output_errors = ((y_pred - y_true) ** 2).mean(axis=0)

if isinstance(multioutput, str):
if isinstance(multioutput, str) or multioutput is None:
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
multioutput = None
if compute:
return output_errors.compute()
else:
return output_errors
else:
raise ValueError("Weighted 'multioutput' not supported.")
result = output_errors.mean()
Expand All @@ -67,12 +67,12 @@ def mean_absolute_error(
_check_sample_weight(sample_weight)
output_errors = abs(y_pred - y_true).mean(axis=0)

if isinstance(multioutput, str):
if isinstance(multioutput, str) or multioutput is None:
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
multioutput = None
if compute:
return output_errors.compute()
else:
return output_errors
else:
raise ValueError("Weighted 'multioutput' not supported.")
result = output_errors.mean()
Expand Down Expand Up @@ -153,7 +153,7 @@ def r2_score(
compute: bool = True,
) -> ArrayLike:
_check_sample_weight(sample_weight)
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
_, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput)
weight = 1.0

numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")
Expand Down
52 changes: 52 additions & 0 deletions tests/metrics/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numbers

import dask.array as da
import numpy as np
import pytest
import sklearn.metrics
from dask.array.utils import assert_eq

import dask_ml.metrics
from dask_ml._compat import SK_024
Expand Down Expand Up @@ -72,3 +74,53 @@ def test_mean_squared_log_error():
result = m1(a, b)
expected = m2(a, b)
assert abs(result - expected) < 1e-5


@pytest.mark.parametrize("multioutput", ["uniform_average", None])
def test_regression_metrics_unweighted_average_multioutput(metric_pairs, multioutput):
m1, m2 = metric_pairs

a = da.random.uniform(size=(100,), chunks=(25,))
b = da.random.uniform(size=(100,), chunks=(25,))

result = m1(a, b, multioutput=multioutput)
expected = m2(a, b, multioutput=multioutput)
assert abs(result - expected) < 1e-5


@pytest.mark.parametrize("compute", [True, False])
def test_regression_metrics_raw_values(metric_pairs, compute):
m1, m2 = metric_pairs

if m1.__name__ == "r2_score":
pytest.skip("r2_score does not support multioutput='raw_values'")

a = da.random.uniform(size=(100, 3), chunks=(25, 3))
b = da.random.uniform(size=(100, 3), chunks=(25, 3))

result = m1(a, b, multioutput="raw_values", compute=compute)
expected = m2(a, b, multioutput="raw_values")

if compute:
assert isinstance(result, np.ndarray)
else:
assert isinstance(result, da.Array)

assert_eq(result, expected)
assert result.shape == (3,)


def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs):
m1, _ = metric_pairs

a = da.random.uniform(size=(100, 3), chunks=(25, 3))
b = da.random.uniform(size=(100, 3), chunks=(25, 3))
weights = da.random.uniform(size=(3,))

if m1.__name__ == "r2_score":
error_msg = "'multioutput' must be 'uniform_average'"
else:
error_msg = "Weighted 'multioutput' not supported."

with pytest.raises((NotImplementedError, ValueError), match=error_msg):
_ = m1(a, b, multioutput=weights)

0 comments on commit db2e7d5

Please sign in to comment.