
The `log_loss` function crashes when given mixed input types

rsundqvist opened this issue 2 years ago

Describe the issue: Calling `dask_ml.metrics.log_loss` with mixed input types (for example, a Dask Series for `y_true` and a Dask Array for `y_pred`) raises a `ValueError`:

```
  File "<python-home>/lib/python3.11/site-packages/dask_ml/metrics/classification.py", line 106, in _log_loss_inner
    [sklearn.metrics.log_loss(x, y, sample_weight=sample_weight, **kwargs)]
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<python-home>/lib/python3.11/site-packages/sklearn/utils/_param_validation.py", line 211, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "<python-home>/lib/python3.11/site-packages/sklearn/metrics/_classification.py", line 2854, in log_loss
    check_consistent_length(y_pred, y_true, sample_weight)
  File "<python-home>/lib/python3.11/site-packages/sklearn/utils/validation.py", line 409, in check_consistent_length
    raise ValueError(
ValueError: Found input variables with inconsistent numbers of samples: [4, 2]
```

The same call works with "vanilla" scikit-learn, pandas, and NumPy functions and types.

Minimal Complete Verifiable Example:

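```python
from dask import array as da, dataframe as dd
from dask_ml.metrics import log_loss

# y_true is a Dask Series with 2 partitions; y_pred is a Dask Array with chunks of 2.
y_true = dd.DataFrame.from_dict({"y_true": [True, False, True, False]}, npartitions=2)["y_true"]
y_pred = da.from_array([0, 1, 0, 1], chunks=2)

# Same types: both inputs are Dask Arrays.
print(f"y_true is array: {log_loss(y_true.to_dask_array(), y_pred).compute()=}")
# Mixed types (Dask Series + Dask Array): raises the ValueError above.
print(f"y_true is series: {log_loss(y_true, y_pred).compute()=}")
```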

Anything else we need to know?:

  • The same error occurs when the types of `y_true` and `y_pred` are swapped.
  • A similar error occurs when both inputs have the same type but different `npartitions`/`chunks` arguments.

Crashing on mismatched partitioning probably makes sense, but it would be nice if the same-type requirement were at least documented. Maybe it is and I just missed it.
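To illustrate why mismatched partitioning can fail, here is a minimal plain-Python sketch. It assumes, as the `_log_loss_inner` frame in the traceback suggests, that the loss is evaluated on each pair of corresponding blocks and the per-block results are then combined. `blockwise_log_loss` and `_block_loss_sum` are hypothetical helpers, not dask-ml API, and the binary-only loss with fixed `eps` clipping is a simplification.

```python
import math


def _block_loss_sum(y_true_block, y_pred_block, eps=1e-15):
    """Summed binary log loss for one pair of blocks (hypothetical helper)."""
    if len(y_true_block) != len(y_pred_block):
        # Same kind of length check that produced the ValueError in the traceback.
        raise ValueError(
            "Found input variables with inconsistent numbers of samples: "
            f"[{len(y_true_block)}, {len(y_pred_block)}]"
        )
    total = 0.0
    for t, p in zip(y_true_block, y_pred_block):
        p = min(max(float(p), eps), 1.0 - eps)  # clip so log() stays finite
        total -= math.log(p) if t else math.log(1.0 - p)
    return total


def blockwise_log_loss(y_true_blocks, y_pred_blocks, eps=1e-15):
    """Mean log loss computed block by block, then combined (hypothetical)."""
    if len(y_true_blocks) != len(y_pred_blocks):
        raise ValueError("inputs have a different number of blocks")
    total = sum(
        _block_loss_sum(t, p, eps) for t, p in zip(y_true_blocks, y_pred_blocks)
    )
    n_samples = sum(len(block) for block in y_true_blocks)
    if n_samples == 0:
        raise ValueError("no samples")
    return total / n_samples


# Aligned blocks: both inputs split as [2, 2].
print(blockwise_log_loss([[1, 0], [1, 0]], [[0.9, 0.1], [0.8, 0.2]]))

# Misaligned blocks: y_true split as [3, 1], y_pred as [2, 2].
try:
    blockwise_log_loss([[1, 0, 1], [0]], [[0.9, 0.1], [0.8, 0.2]])
except ValueError as err:
    print(err)
```

Because each block is scored independently, the two inputs must be split at the same boundaries. A Dask Series and a Dask Array, or two arrays with different chunks, are not guaranteed to be.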

Environment:

  • Dask version: 2023.8.0 (dask-ml: 2023.3.24)
  • Python version: 3.11.3
  • Operating System: Ubuntu 22.04.2 LTS, running in VMware Workstation Player 17 (if that matters).
  • Install method (conda, pip, source): pip

rsundqvist, Aug 15 '23 11:08