diff --git a/dask_ml/metrics/regression.py b/dask_ml/metrics/regression.py index d1849ef88..b28834e76 100644 --- a/dask_ml/metrics/regression.py +++ b/dask_ml/metrics/regression.py @@ -16,7 +16,7 @@ def _check_sample_weight(sample_weight: Optional[ArrayLike]): def _check_reg_targets( y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str] ): - if multioutput != "uniform_average": + if multioutput is not None and multioutput != "uniform_average": raise NotImplementedError("'multioutput' must be 'uniform_average'") if y_true.ndim == 1: @@ -40,12 +40,12 @@ def mean_squared_error( _check_sample_weight(sample_weight) output_errors = ((y_pred - y_true) ** 2).mean(axis=0) - if isinstance(multioutput, str): + if isinstance(multioutput, str) or multioutput is None: if multioutput == "raw_values": - return output_errors - elif multioutput == "uniform_average": - # pass None as weights to np.average: uniform mean - multioutput = None + if compute: + return output_errors.compute() + else: + return output_errors else: raise ValueError("Weighted 'multioutput' not supported.") result = output_errors.mean() @@ -67,12 +67,12 @@ def mean_absolute_error( _check_sample_weight(sample_weight) output_errors = abs(y_pred - y_true).mean(axis=0) - if isinstance(multioutput, str): + if isinstance(multioutput, str) or multioutput is None: if multioutput == "raw_values": - return output_errors - elif multioutput == "uniform_average": - # pass None as weights to np.average: uniform mean - multioutput = None + if compute: + return output_errors.compute() + else: + return output_errors else: raise ValueError("Weighted 'multioutput' not supported.") result = output_errors.mean() @@ -90,7 +90,7 @@ def r2_score( compute: bool = True, ) -> ArrayLike: _check_sample_weight(sample_weight) - _, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput) + _, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput) weight = 1.0 numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8") diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py index dfdc5480c..8005a515e 100644 --- a/tests/metrics/test_regression.py +++ b/tests/metrics/test_regression.py @@ -1,8 +1,10 @@ import numbers import dask.array as da +import numpy as np import pytest import sklearn.metrics +from dask.array.utils import assert_eq import dask_ml.metrics @@ -60,3 +62,53 @@ def test_mean_squared_log_error(): result = m1(a, b) expected = m2(a, b) assert abs(result - expected) < 1e-5 + + +@pytest.mark.parametrize("multioutput", ["uniform_average", None]) +def test_regression_metrics_unweighted_average_multioutput(metric_pairs, multioutput): + m1, m2 = metric_pairs + + a = da.random.uniform(size=(100,), chunks=(25,)) + b = da.random.uniform(size=(100,), chunks=(25,)) + + result = m1(a, b, multioutput=multioutput) + expected = m2(a, b, multioutput=multioutput) + assert abs(result - expected) < 1e-5 + + +@pytest.mark.parametrize("compute", [True, False]) +def test_regression_metrics_raw_values(metric_pairs, compute): + m1, m2 = metric_pairs + + if m1.__name__ == "r2_score": + pytest.skip("r2_score does not support multioutput='raw_values'") + + a = da.random.uniform(size=(100, 3), chunks=(25, 3)) + b = da.random.uniform(size=(100, 3), chunks=(25, 3)) + + result = m1(a, b, multioutput="raw_values", compute=compute) + expected = m2(a, b, multioutput="raw_values") + + if compute: + assert isinstance(result, np.ndarray) + else: + assert isinstance(result, da.Array) + + assert_eq(result, expected) + assert result.shape == (3,) + + +def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs): + m1, _ = metric_pairs + + a = da.random.uniform(size=(100, 3), chunks=(25, 3)) + b = da.random.uniform(size=(100, 3), chunks=(25, 3)) + weights = da.random.uniform(size=(3,)) + + if m1.__name__ == "r2_score": + error_msg = "'multioutput' must be 'uniform_average'" + else: + error_msg = "Weighted 'multioutput' not supported." + + with pytest.raises((NotImplementedError, ValueError), match=error_msg): + _ = m1(a, b, multioutput=weights)