Stat_Scores does not behave as expected
🐛 Bug
torchmetrics.functional.stat_scores returns wrong values for binary inputs.
To Reproduce
Steps to reproduce the behavior:
import torch
from torchmetrics.functional import stat_scores
preds = torch.tensor([1, 0, 0, 1])
target = torch.tensor([1, 1, 0, 0])
stat_scores(preds, target)
The output of this function is as follows:
tensor([2, 2, 2, 2, 4])
Expected behavior
This function is supposed to return the true positives, false positives, true negatives, false negatives and support (TP + FN). In the given example there is exactly one sample of each kind, so the first four values should be 1 and the support should be 2. Expected behaviour: tensor([1, 1, 1, 1, 2])
Instead, each of the first four values is 2 and the support is 4.
Environment
- PyTorch 1.10
- Linux (On Colab):
- Pytorch Preinstalled
- Python version: 3.7.12
- Tried on both CPU and GPU
Hi! Thanks for your contribution! Great first issue!
@SkafteNicki or @lucadiliello, mind having a look? :]
I found that the bug is caused by the checks on the following lines:
https://github.com/PyTorchLightning/metrics/blob/e2f7105ddd97371042b9a5bc5a49404109d42c83/torchmetrics/utilities/checks.py#L418 and
https://github.com/PyTorchLightning/metrics/blob/e2f7105ddd97371042b9a5bc5a49404109d42c83/torchmetrics/utilities/checks.py#L421
and the fact that the default value of multiclass is None.
However, by changing the checks to
if not multiclass:
and
if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and multiclass) or multiclass:
(which should be equivalent to just):
if multiclass:
some other tests are going to fail.
Since I'm not an expert on the classification metrics package and I'm short on time, I cannot dig any further :/
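The claimed equivalence between the two conditions above can be checked exhaustively. The following is only a sketch: is_multiclass_case is a hypothetical stand-in for the `case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS)` test, and multiclass ranges over True, False and the default None. It says nothing about why the other tests fail.

```python
from itertools import product

def conditions_agree():
    """Compare the long and short form of the check for every input combination."""
    results = []
    # is_multiclass_case is a hypothetical placeholder for the DataType membership test
    for is_multiclass_case, multiclass in product([True, False], [True, False, None]):
        long_form = bool((is_multiclass_case and multiclass) or multiclass)
        short_form = bool(multiclass)
        results.append(long_form == short_form)
    return results

print(all(conditions_agree()))  # True

# With `if not multiclass:`, the default None is treated the same as False
print(not None, not False)  # True True
```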
This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes.
Small recap: This issue describes that the stat_scores metric does not work correctly in the binary setting. The problem with the current implementation is that the statistics are calculated as a sum over both the 0 and the 1 class, which is wrong.
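To make the recap concrete, here is a dependency-free sketch that counts the statistics by hand for the example from the issue description. The helper per_class_counts is hypothetical and is not the torchmetrics implementation; it only illustrates how summing the counts of both classes reproduces the reported output.

```python
def per_class_counts(preds, target, cls):
    """Return [TP, FP, TN, FN, support], treating `cls` as the positive class."""
    pairs = list(zip(preds, target))
    tp = sum(p == cls and t == cls for p, t in pairs)
    fp = sum(p == cls and t != cls for p, t in pairs)
    tn = sum(p != cls and t != cls for p, t in pairs)
    fn = sum(p != cls and t == cls for p, t in pairs)
    return [tp, fp, tn, fn, tp + fn]

preds = [1, 0, 0, 1]
target = [1, 1, 0, 0]

class_0 = per_class_counts(preds, target, cls=0)
class_1 = per_class_counts(preds, target, cls=1)
# Summing over both classes doubles every count
summed = [a + b for a, b in zip(class_0, class_1)]

print(class_1)  # [1, 1, 1, 1, 2]  (the expected binary result)
print(summed)   # [2, 2, 2, 2, 4]  (the output reported in this issue)
```

In the binary setting only the counts for class 1 are wanted, which is why the reported values are exactly twice the expected ones here.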
After the refactor, we introduce the specialized binary_stat_scores:
from torchmetrics.functional import binary_stat_scores
import torch
preds = torch.tensor([1, 0, 0, 1])
target = torch.tensor([1, 1, 0, 0])
binary_stat_scores(preds, target) # tensor([1, 1, 1, 1, 2])
which is what is expected for this example (as stated in the issue description). Sorry for the confusion this has caused. The issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.