torchmetrics

Classification refactor

Open SkafteNicki opened this issue 3 years ago • 3 comments

The classification package is long overdue for a refactor. We are seeing a rising number of issues that either request new features that are hard to implement in the current codebase or reveal a disagreement between what users expect the metrics to do and what they actually do.

A full list of issues that should be addressed by the refactor can be found here.

The refactor aims to address the following problems:

  • Maintainability: The core part of the classification metrics was written by a contributor back when torchmetrics was still pl.metrics. We should perhaps have been more thorough in the review phase, as the code has proven hard to maintain. This refactor should help address this by lowering code complexity.
  • Consistency: The classification package has a number of consistency issues. This issue gives an overview of some of them, but essentially num_classes=2 sometimes means binary classification and sometimes means multiclass classification (the two are defined differently), depending on which metric you are using.
  • Expectations: In a number of cases the default arguments do not match user expectations, because we differ slightly from how sklearn handles some cases. This refactor will address these differences.
  • Performance: We have received feedback that our implementations are very slow. While these comparisons are not always fair (comparing a one-line implementation of accuracy with our modular and much more general implementation), it is fair to say that some improvements can be made.
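The num_classes=2 ambiguity is easiest to see on a tiny example. The plain-Python sketch below (hypothetical helper names, not the torchmetrics API) scores the same predictions once as binary precision, where only label 1 counts as positive, and once as macro-averaged two-class precision, where each class takes a turn as the positive class and the per-class scores are averaged:

```python
from typing import List


def binary_precision(preds: List[int], target: List[int]) -> float:
    """Precision of the positive class (label 1) only."""
    tp = sum(p == 1 and t == 1 for p, t in zip(preds, target))
    fp = sum(p == 1 and t == 0 for p, t in zip(preds, target))
    return tp / (tp + fp) if tp + fp else 0.0


def macro_precision(preds: List[int], target: List[int], num_classes: int) -> float:
    """Per-class precision averaged over all classes, class 0 included."""
    scores = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, target))
        fp = sum(p == c and t != c for p, t in zip(preds, target))
        scores.append(tp / (tp + fp) if tp + fp else 0.0)
    return sum(scores) / num_classes


preds = [0, 0, 1, 1]
target = [0, 0, 0, 1]
print(binary_precision(preds, target))                # 0.5
print(macro_precision(preds, target, num_classes=2))  # 0.75
```

Neither number is wrong; they answer different questions, which is why the task should be stated explicitly rather than inferred from num_classes.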

Proposed solution

The proposed solution is to split each metric into three separate metric classes:

  • BinaryMetricName
  • MultiClassMetricName
  • MultiLabelMetricName

For example, Accuracy will be split into BinaryAccuracy, MultiClassAccuracy and MultiLabelAccuracy. This solution directly solves a number of problems:

  1. While most metrics support all three tasks, some only support one or two of them. Renaming and splitting all metrics will make it much clearer which metric supports which mode.
  2. Some performance is lost to extensive input validation. By forcing the user to specify beforehand which task is being evaluated, we can reduce the amount of input validation needed.
  3. Code complexity: instead of every metric essentially containing if-else statements based on which task we are trying to solve, each metric will have a much clearer computational path.
  4. Input arguments: take threshold and topk as examples, which are currently arguments to the Accuracy, F1, etc. metrics. threshold should only be set for binary and multilabel tasks, and topk should only be specified for multiclass. Dividing into separate metrics helps communicate which arguments influence the computation.
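To illustrate points 3 and 4, here is a framework-free sketch in which each task gets its own straight-line computation: threshold only exists where scores must be binarized, and top_k only where classes are ranked. The function names and input conventions (probabilities for binary and multilabel, one score per class for multiclass) are assumptions for illustration, not the planned torchmetrics signatures.

```python
from typing import Sequence


def binary_accuracy(preds: Sequence[float], target: Sequence[int], threshold: float = 0.5) -> float:
    # preds are positive-class probabilities; threshold turns them into 0/1 labels
    hits = sum(int(p >= threshold) == t for p, t in zip(preds, target))
    return hits / len(target)


def multiclass_accuracy(preds: Sequence[Sequence[float]], target: Sequence[int], top_k: int = 1) -> float:
    # preds hold one score per class; a sample is correct if its target is among the top_k scores
    hits = 0
    for scores, t in zip(preds, target):
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        hits += t in ranked[:top_k]
    return hits / len(target)


def multilabel_accuracy(preds: Sequence[Sequence[float]], target: Sequence[Sequence[int]], threshold: float = 0.5) -> float:
    # every (sample, label) pair is an independent binary decision
    pairs = [(p, t) for ps, ts in zip(preds, target) for p, t in zip(ps, ts)]
    hits = sum(int(p >= threshold) == t for p, t in pairs)
    return hits / len(pairs)


print(binary_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))                     # 0.75
print(multiclass_accuracy([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [2, 0], top_k=2))  # 1.0
print(multilabel_accuracy([[0.8, 0.1], [0.3, 0.6]], [[1, 0], [1, 1]]))         # 0.75
```

None of the three functions needs to inspect input shapes to decide what task it is solving, which is where the reduced validation cost in point 2 would come from.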

Alternatives

  1. We keep everything in one class but introduce a new (required) argument:

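    from typing import Literal

    from torchmetrics import Metric


    class Accuracy(Metric):
        def __init__(
            self,
            mode: Literal["binary", "multiclass", "multilabel"],
            **kwargs,  # remaining task-specific arguments
        ):
            super().__init__()
            ...  # update() and compute() omitted in this sketch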
    

    This alternative is not directly in opposition to the main proposed solution. If requested by users, we could still provide a single class that simply wraps the three individual metric classes.

  2. We keep the outer API the exact same and try to clean up the internals. This will most likely only address some of the current problems.
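The two alternatives are not mutually exclusive. A minimal sketch of how a single Accuracy class could wrap the split classes, assuming the three task-specific classes exist (stubbed here with only hypothetical constructor arguments such as num_labels), is a thin dispatcher whose __new__ picks the class from the required mode argument and returns an instance of it:

```python
from typing import Literal


# Hypothetical stubs: each task-specific class only accepts the arguments relevant to its task
class BinaryAccuracy:
    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold


class MultiClassAccuracy:
    def __init__(self, num_classes: int, top_k: int = 1) -> None:
        self.num_classes = num_classes
        self.top_k = top_k


class MultiLabelAccuracy:
    def __init__(self, num_labels: int, threshold: float = 0.5) -> None:
        self.num_labels = num_labels
        self.threshold = threshold


_TASK_CLASSES = {
    "binary": BinaryAccuracy,
    "multiclass": MultiClassAccuracy,
    "multilabel": MultiLabelAccuracy,
}


class Accuracy:
    """Thin wrapper: returns an instance of the task-specific class selected by `mode`."""

    def __new__(cls, mode: Literal["binary", "multiclass", "multilabel"], **kwargs):
        if mode not in _TASK_CLASSES:
            options = ", ".join(c.__name__ for c in _TASK_CLASSES.values())
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {sorted(_TASK_CLASSES)}. "
                f"Alternatively, use {options} directly."
            )
        # The returned object is not an Accuracy instance, so Python skips Accuracy.__init__
        return _TASK_CLASSES[mode](**kwargs)


metric = Accuracy("multiclass", num_classes=3, top_k=2)
print(type(metric).__name__)  # MultiClassAccuracy
```

Because each stub only accepts the arguments that matter for its task, a call such as Accuracy("multiclass", num_classes=3, threshold=0.5) fails with a TypeError instead of silently ignoring threshold. The error branch for an unknown mode is also a natural place to point users at the task-specific classes, as suggested in the discussion below.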

Integration

  • PL: This should not be a problem, as core PL does not depend on the classification package, only on the base torchmetrics.Metric class, which should not need to be touched during the refactor. Some examples may need to be updated.
  • Flash: Some changes would need to be made, but we expect them to be minimal, as initializing a Flash Task should already provide all the information needed to determine which class to use. cc: @ethanwharris

Deprecation

The goal is to have the whole classification package refactored and cleaned up as the major work in 0.10. All current classification metrics will be given a deprecation warning, and users will have until 0.11 to update their code to use the new classes.
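A deprecation of this kind is commonly done with Python's warnings module. The sketch below uses a hypothetical stand-in for a legacy Accuracy class; the message text and constructor arguments are illustrative only:

```python
import warnings


class Accuracy:
    """Hypothetical stand-in for a current (pre-refactor) classification metric."""

    def __init__(self, threshold: float = 0.5, top_k: int = 1) -> None:
        # stacklevel=2 attributes the warning to the line that constructed the metric
        warnings.warn(
            "`Accuracy` is deprecated in v0.10 and will be removed in v0.11. "
            "Use `BinaryAccuracy`, `MultiClassAccuracy` or `MultiLabelAccuracy` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # The legacy behaviour is otherwise unchanged during the deprecation window
        self.threshold = threshold
        self.top_k = top_k
```

One design note: Python hides DeprecationWarning by default unless it is triggered from code in __main__ (test runners such as pytest do display it), whereas FutureWarning is always shown, so the choice of category affects how many users actually see the message.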

While we are developing the new package, we will freeze the addition of new metrics to the classification package. We will still happily accept new metrics for other domains.

Documentation impact

Up until v0.8, this change would have made our documentation very annoying to scroll through, as everything was on one central page; splitting every metric in three would essentially have made the classification documentation three times harder to navigate.

However, from v0.8 we changed it to have one page per metric. For this refactor we would keep one page per core metric (e.g. Accuracy, Precision, Recall), and each page would then list every version of the metric.

Development

The development can essentially be divided into 3 phases:

  1. Development of a generalized StatScore and ConfusionMatrix class for all three modes. Many classification metrics can be calculated from these statistics.
  2. Subclass the generalized classes into specific metrics.
  3. Deal with metrics that do not fit into either of the two generalized classes (currently about 1/3 of the classification metrics).
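Phases 1 and 2 can be pictured with a small plain-Python sketch: a base class accumulates true/false positives and negatives in update(), and subclasses only override compute() to turn those counts into a metric. The class names borrow the naming scheme proposed above, but this is a hypothetical illustration for the binary case only, without tensors, distributed syncing or the multiclass/multilabel variants:

```python
from typing import Sequence, Tuple


class BinaryStatScores:
    """Accumulates tp/fp/tn/fn across batches (hypothetical sketch)."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold
        self.tp = self.fp = self.tn = self.fn = 0

    def update(self, preds: Sequence[float], target: Sequence[int]) -> None:
        for p, t in zip(preds, target):
            label = int(p >= self.threshold)  # binarize the probability
            if label == 1 and t == 1:
                self.tp += 1
            elif label == 1:
                self.fp += 1
            elif t == 1:
                self.fn += 1
            else:
                self.tn += 1

    def compute(self) -> Tuple[int, int, int, int]:
        return self.tp, self.fp, self.tn, self.fn


class BinaryPrecision(BinaryStatScores):
    # Phase 2: only the final formula differs from the base class
    def compute(self) -> float:
        denom = self.tp + self.fp
        return self.tp / denom if denom else 0.0


class BinaryRecall(BinaryStatScores):
    def compute(self) -> float:
        denom = self.tp + self.fn
        return self.tp / denom if denom else 0.0


precision = BinaryPrecision()
precision.update([0.9, 0.8], [1, 1])  # first batch
precision.update([0.7], [0])          # second batch
print(precision.compute())  # 0.666... (2 true positives, 1 false positive)
```

The appeal of this layout is that input handling lives in one place, so adding, say, a specificity metric would only require another compute() built on the same four counts.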

The main part of the refactor will be done by @SkafteNicki and @justusschock, with support from the rest of the core metrics team. We may be open to contributions for step 2, as it should be fairly simple subclassing and copy-paste work. Development should start within two weeks.

Any feedback is appreciated :)

SkafteNicki, May 03 '22 08:05

I like that we are thinking of simplifying this for the users. The classification module has grown organically - many more options have been added over time. I agree a refactor is appropriate here.

I like alternative approach 1, where we keep the current basic classes as light wrappers. These are easy to remember and frequently used (Accuracy, F1, Precision ...) by a large user base. If we go with the wrapper class and do validation inside of it, perhaps we can recommend the specific classes BinaryMetricName, MultiClassMetricName and MultiLabelMetricName depending on the error and use case.

awaelchli, May 03 '22 11:05

Sounds good. I agree with @awaelchli that it would be nice to keep the base Accuracy etc. classes; adding e.g. the mode argument seems reasonable. Flash can be updated to work with either API :smiley:

ethanwharris, May 03 '22 11:05

While it is not directly related to this discussion, this issue may also be relevant, since it highlights one more aspect of accuracy-like metrics that could be taken into account during the refactoring.

Yura52, Jun 01 '22 09:06