Behavior of regression metrics when multioutput=None is inconsistent with scikit-learn #818

Closed
jameslamb opened this issue Apr 6, 2021 · 0 comments · Fixed by #820

What happened:

Passing multioutput=None to the regression metrics in dask_ml.metrics.regression raises ValueError: Weighted 'multioutput' not supported.

What you expected to happen:

I expected the behavior to be the same as the equivalent scikit-learn metrics functions, where multioutput=None means "all elements have the same weight".

Minimal Complete Verifiable Example:

In scikit-learn, the value of multioutput is passed through to np.average().

https://github.com/scikit-learn/scikit-learn/blob/877c6e6db42006445ccf0695c0dde3294ff4dd4a/sklearn/metrics/_regression.py#L195

np.average() treats weights=None as "equally weighted", so multioutput=None behaves the same as multioutput="uniform_average". From the NumPy documentation (https://numpy.org/doc/stable/reference/generated/numpy.average.html):

If weights=None, then all data in a are assumed to have a weight equal to one.

import numpy as np
from sklearn.metrics import mean_squared_error

a = np.random.uniform(size=(100, 3))
b = np.random.uniform(size=(100, 3))

raw_output = mean_squared_error(a, b, multioutput="raw_values")
print(raw_output)

# [0.17118814 0.15964742 0.13095381]

mean_squared_error(a, b, multioutput=None)

# 0.15392978995168105
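The link between the two results above can be checked with NumPy alone. This is a minimal sketch that reuses the per-output errors printed for multioutput="raw_values"; the variable names are illustrative and are not part of scikit-learn.

```python
import numpy as np

# Per-output errors copied from the "raw_values" printout above.
raw_values = np.array([0.17118814, 0.15964742, 0.13095381])

# weights=None gives every output the same weight ...
unweighted = np.average(raw_values, weights=None)

# ... which matches explicit uniform weights and a plain mean.
uniform = np.average(raw_values, weights=np.ones_like(raw_values))
plain_mean = raw_values.mean()

print(unweighted, uniform, plain_mean)
```

All three values should agree with the multioutput=None result up to the rounding of the printed raw values.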

In dask-ml, passing multioutput=None results in an error.

import dask.array as da
from dask_ml.metrics.regression import mean_squared_error

a = da.random.uniform(size=(100, 3))
b = da.random.uniform(size=(100, 3))

raw_output = mean_squared_error(a, b, multioutput="raw_values")
print(raw_output)
# dask.array<mean_agg-aggregate, shape=(3,), dtype=float64, chunksize=(3,), chunktype=numpy.ndarray>

mean_squared_error(a, b, multioutput=None)
# ValueError: Weighted 'multioutput' not supported.
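For illustration, here is one way the reduction step could treat None the same as "uniform_average". This is only a sketch under assumptions: reduce_multioutput is a hypothetical helper written against NumPy arrays, loosely following the scikit-learn pattern linked above. It is not dask-ml's actual code and not necessarily how the fix was implemented.

```python
import numpy as np


def reduce_multioutput(output_errors, multioutput="uniform_average"):
    """Hypothetical helper: combine per-output errors into a final result.

    multioutput may be "raw_values", "uniform_average", None,
    or an array-like of per-output weights.
    """
    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            # Return one error per output, no aggregation.
            return output_errors
        if multioutput == "uniform_average":
            # Equal weights are expressed to np.average as weights=None.
            multioutput = None
        else:
            raise ValueError(f"Unsupported multioutput value: {multioutput!r}")
    # None means equal weights; an array-like means a weighted average.
    return np.average(output_errors, weights=multioutput)


errors = np.array([0.2, 0.4, 0.6])
print(reduce_multioutput(errors, multioutput=None))
print(reduce_multioutput(errors, multioutput="uniform_average"))
print(reduce_multioutput(errors, multioutput="raw_values"))
```

With this shape, None and "uniform_average" follow the same code path, so they cannot diverge.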

Anything else we need to know?:

Environment:

  • Dask version:
output of 'conda list | grep -E "dask|distributed"'
dask                      2021.3.0           pyhd3eb1b0_0  
dask-core                 2021.3.0           pyhd3eb1b0_0  
dask-glm                  0.2.0                    pypi_0    pypi
dask-saturn               0.2.2                    pypi_0    pypi
dask-sphinx-theme         1.3.5                    pypi_0    pypi
distributed               2021.3.0         py37h06a4308_0  
  • Python version, Operating System
output of 'conda info'

     active environment : None
       user config file : /home/jlamb/.condarc
 populated config files : 
          conda version : 4.9.2
    conda-build version : not installed
         python version : 3.7.6.final.0
       virtual packages : __glibc=2.27=0
                          __unix=0=0
                          __archspec=1=x86_64
       base environment : /home/jlamb/miniconda3  (writable)
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /home/jlamb/miniconda3/pkgs
                          /home/jlamb/.conda/pkgs
       envs directories : /home/jlamb/miniconda3/envs
                          /home/jlamb/.conda/envs
               platform : linux-64
             user-agent : conda/4.9.2 requests/2.22.0 CPython/3.7.6 Linux/5.4.0-70-generic ubuntu/18.04.4 glibc/2.27
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
  • Install method (conda, pip, source): installed from source as of f5e5bb4, using python setup.py install.
TomAugspurger pushed a commit that referenced this issue Apr 10, 2021
…-learn (fixes #818) (#820)

* make regression metrics 'multioutput' behavior consistent with scikit-learn

* add test on error messages

* linting