[RFC] Loss Functions in Torchvision
🚀 Feature
A loss functions API in torchvision.
Motivation
The request is simple: torchvision already provides some loss functions,
e.g. `sigmoid_focal_loss`, `l1_loss`, but they are scattered and have to be used as `torchvision.ops.sigmoid_focal_loss`, etc.
In the future, we may need to include further loss functions, e.g. `dice_loss`.
Since loss functions are differentiable, we can put them under an `nn` namespace.
We can have
`torchvision.nn.losses.sigmoid_focal_loss` and so on.
This keeps the scope of `nn` open for other differentiable components, such as layers.
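A minimal sketch of what that could look like; the module path and shim below are the proposal, not an existing API:

```python
# Hypothetical torchvision/nn/losses/__init__.py -- not an existing module.
# It could start as a thin shim re-exporting the current functional losses,
# so both import paths keep working during a migration:
#
#   from torchvision.ops.focal_loss import sigmoid_focal_loss  # noqa: F401
```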
Pitch
These losses are specific to the vision domain, broadly useful, and in general not tied to any particular model.
Loss functions usually live in `torch`, though; keeping these under an `nn` namespace keeps a future migration simple:
`torchvision.nn.sigmoid_focal_loss` would simply become `torch.nn.sigmoid_focal_loss`.
This pitch comes from the issues above; see also More Loss Functions.
Alternatives
Alternatively, these could go directly into torch. But with the approach above, we can support them in torchvision first, and later deprecate and move them to torch when needed.
Currently we include them under `ops`, but a differentiable loss function is not really an operation:
the other ops are not differentiable and perform transformations / manipulations over boxes/layers.
Additional context
Here is a list of loss functions we would like to include.
- [x] LabelSmoothing Loss https://github.com/pytorch/pytorch/pull/63122
- [x] SoftTarget CrossEntropy https://github.com/pytorch/pytorch/pull/61044
- [x] Huber Loss https://github.com/pytorch/pytorch/pull/50553
- [ ] Barron loss Implemented in Classy Vision
- [ ] JSD Loss
- [ ] Dice Loss
- [x] gIoU Loss Used in DETR.
- [x] Refactor Current Focal Loss from ops to nn.
- [x] Refactor FRCNN Smooth L1 Loss to nn.
- [ ] Super Loss https://github.com/pytorch/pytorch/issues/49851
- [ ] TripletMarginLoss This has a similar issue to LabelSmoothing. `TripletMarginLoss` is supported in PyTorch, but we use a variant of it in torchvision references for similarity search.
- [ ] DeepLabCELoss This is implemented in Detectron2, but in torchvision references and model training we use `nn.CrossEntropy()` with a little modification to the aux loss.
- [ ] Multi-Class Focal Loss
- [ ] PSNR Loss Having PSNR in `torchvision.ops` would also be nice.
- [x] Distance-IoU & Complete-IoU loss - see here - #5776 #5786
- [ ] SIoU loss
- [ ] Federated loss
- [ ] Poly Loss https://github.com/pytorch/pytorch/issues/76732
- [ ] Tversky Loss
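Several of the items above (Dice, Tversky) reduce to short overlap formulas. As a point of reference, here is a framework-free sketch of the soft Dice loss; the name `dice_loss`, the flat-list-of-floats interface, and the `eps` handling are illustrative simplifications — a torchvision implementation would operate on batched torch tensors.

```python
def dice_loss(probs, targets, eps=1e-7):
    """Soft Dice loss: 1 - 2*|P.T| / (|P| + |T|).

    probs: predicted foreground probabilities, targets: binary ground truth.
    Sketch only -- flat lists of floats instead of batched tensors.
    """
    inter = sum(p * t for p, t in zip(probs, targets))
    total = sum(probs) + sum(targets)
    return 1.0 - (2.0 * inter + eps) / (total + eps)

# A perfect prediction gives ~0 loss; a fully disjoint one gives ~1.
print(dice_loss([1.0, 0.0, 1.0], [1, 0, 1]))  # ~0.0
print(dice_loss([1.0, 0.0], [0, 1]))          # ~1.0
```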
References
We can refer to Kornia, fvcore and a few PyTorch issues that request this feature.
cc. @dongreenberg @cpuhrsch I think this is very reasonable; there is similar work going on for torchaudio too. We should resume the library naming convention discussion and wrap it up to provide a comprehensive solution for losses/metrics.
@oke-aditya The domain team had a brief discussion on this;
- we agree that domain specific loss functions are coming up.
- But we would like to hold off on creating a dedicated module until we actually have a good number of functions that fall into that category. It's easy to add such a module, but once added it cannot be removed.
- meanwhile, we can update the documentation and add a new category so that the existing losses are easy to find.
@oke-aditya What do you think?
cc @fmassa
I agree with your thoughts @mthrok. It may be too early to call for such an API. Yes, it would be nice to update the documentation.
I have a use-case that requires LabelSmoothing. Unfortunately CrossEntropyLoss does not support it in PyTorch (https://github.com/pytorch/pytorch/issues/7455). This is a highly requested feature but unfortunately it's been blocked for more than 2 years. Thus I'm tempted to add it on TorchVision side until the above is resolved, but as @oke-aditya pointed out there is no great place to put it.
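Until that lands upstream, label smoothing is easy to express on top of NLL: the target distribution puts 1 - eps on the true class and spreads eps uniformly over all classes. A framework-free sketch follows (plain floats and a hypothetical signature, not the eventual PyTorch/torchvision API):

```python
import math

def label_smoothed_ce(log_probs, target, eps=0.1):
    # Cross entropy against q = (1 - eps) * one_hot(target) + eps * uniform,
    # which works out to a blend of NLL and the mean negative log-probability.
    k = len(log_probs)
    uniform = -sum(log_probs) / k
    nll = -log_probs[target]
    return (1.0 - eps) * nll + eps * uniform

log_probs = [math.log(0.7), math.log(0.2), math.log(0.1)]
print(label_smoothed_ce(log_probs, 0, eps=0.0))  # plain NLL: -log(0.7)
print(label_smoothed_ce(log_probs, 0, eps=0.1))  # slightly higher
```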
@oke-aditya It might be worth keeping track of the losses requested to be added here, so that we can see if we have a critical mass to move this forward. Would you be able to update the ticket description with the list of the current loss functions we want to add on the domain side?
Great point @datumbox Sure :smile: I will update the Issue description. Let's keep this issue for tracking purpose. Feel free to modify it if I miss something.
I guess this issue still needs discussion, so there is no point in trying to contribute a loss for now? :thinking:
@yassineAlouini Wow what a coincidence! Today I was working on something related. :)
At #5444, I have an experimental private function that makes it possible to switch between losses. There are no plans for it to become public any time soon but I was thinking of implementing the Distance-IoU & Complete-IoU losses listed on the ticket.
If you are interested in contributing them, let me know.
Adding cIoU and dIoU should be straightforward. It's been on my mind for a while too.
Let me know if you need help. Or I can pick it up as well :) @yassineAlouini
@oke-aditya @yassineAlouini It would be awesome if you could help with the development and review of these 2 losses. For now we will put them flat in the ops package, similar to giou.
Sure 😃
Yes, it works for me, thanks @datumbox. 👌 Which one should I pick? Do you have a preference @oke-aditya?
Pick Anything you like :)
Thanks @oke-aditya. Let me give dIoU a try and see if I can also do cIoU next. I guess you can help me with review since you know this part of the repo better. I can work on this around 1 day per week. 👌
@oke-aditya, have you started working on cIoU? If not, can I take it? Thanks.
Sure @abhi-glitchhg feel free to take it. I'm happy reviewing the PR.
In case you didn't know, here is the Detectron2 implementation of both of these: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py#L66
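For orientation while reviewing: the Distance-IoU loss (Zheng et al., 2019) is 1 - IoU plus a normalized center-distance penalty. Below is a framework-free sketch on single `(x1, y1, x2, y2)` boxes; the scalar interface is illustrative, since Detectron2 and the torchvision PRs work on batched tensors:

```python
def diou_loss(box1, box2):
    # Distance-IoU: 1 - IoU + d^2 / c^2, where d is the distance between box
    # centers and c the diagonal of the smallest box enclosing both.
    x1, y1, x2, y2 = box1
    X1, Y1, X2, Y2 = box2
    iw = max(0.0, min(x2, X2) - max(x1, X1))
    ih = max(0.0, min(y2, Y2) - max(y1, Y1))
    inter = iw * ih
    union = (x2 - x1) * (y2 - y1) + (X2 - X1) * (Y2 - Y1) - inter
    iou = inter / union if union > 0 else 0.0
    # squared distance between box centers
    d2 = ((x1 + x2 - X1 - X2) ** 2 + (y1 + y2 - Y1 - Y2) ** 2) / 4.0
    # squared diagonal of the smallest enclosing box
    c2 = (max(x2, X2) - min(x1, X1)) ** 2 + (max(y2, Y2) - min(y1, Y1)) ** 2
    return 1.0 - iou + (d2 / c2 if c2 > 0 else 0.0)

print(diou_loss((0, 0, 2, 2), (0, 0, 2, 2)))  # identical boxes -> 0.0
print(diou_loss((0, 0, 1, 1), (1, 0, 2, 1)))  # side-by-side boxes -> 1.2
```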
Logistics question: should I create an issue and add some details, or is it enough to start a branch from my work and open a PR later? @datumbox
This issue tracks it. You can create a fresh branch from main and raise a PR :)
@yassineAlouini What Aditya said ^. :)
No need for a separate ticket, we've got plenty that mention it already. When you bring the PR, I'll tag it accordingly. Just make sure you mark it as draft until you are ready for review.
Should I also add a test for CIOU loss?
I tried finding a test for generalised iou loss in test_ops.py but did not find any. So just want to confirm.
Yeah, the tests are not present as of yet; see #5688. We can add them along with the PR. Mostly you can check cases such as overlapping boxes, side-by-side boxes, etc., something along the lines of https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
But perhaps we should have a look at the implementation first?
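To make the suggested cases concrete, here is a framework-free sketch of such checks against the GIoU loss formula over identical, touching, and fully separated boxes. The plain-float helper and names are illustrative; the real tests in `test_ops.py` would exercise the tensor-based torchvision ops:

```python
def giou_loss(b1, b2, eps=1e-7):
    # 1 - GIoU (Rezatofighi et al., 2019), boxes as (x1, y1, x2, y2).
    x1, y1, x2, y2 = b1
    X1, Y1, X2, Y2 = b2
    iw = max(0.0, min(x2, X2) - max(x1, X1))
    ih = max(0.0, min(y2, Y2) - max(y1, Y1))
    inter = iw * ih
    union = (x2 - x1) * (y2 - y1) + (X2 - X1) * (Y2 - Y1) - inter
    iou = inter / (union + eps)
    # area of the smallest box enclosing both inputs
    area_c = (max(x2, X2) - min(x1, X1)) * (max(y2, Y2) - min(y1, Y1))
    return 1.0 - (iou - (area_c - union) / (area_c + eps))

# the kinds of cases the fvcore tests cover:
assert abs(giou_loss((0, 0, 1, 1), (0, 0, 1, 1))) < 1e-5        # identical -> 0
assert abs(giou_loss((0, 0, 1, 1), (1, 0, 2, 1)) - 1.0) < 1e-5  # touching -> 1
assert giou_loss((0, 0, 1, 1), (2, 0, 3, 1)) > 1.0              # separated -> >1
```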
@oke-aditya @abhi-glitchhg I think we need to keep our code in sync, otherwise we might end up with quite different implementations for cIoU and dIoU. I will try to push a draft PR as soon as possible (around the end of this week or start of next week).
hey don't worry @yassineAlouini code reviews will make sure we are consistent :) We will sync it up. Feel free to work independently.
Agreed. Might be worth to start with the implementations and by then the test should be in. I'll keep an eye for your PRs as I'm currently keeping track of all related work at #5410.
Some progress here (it is still a draft): https://github.com/pytorch/vision/pull/5786
It seems the variety of IoU losses keeps evolving. After gIoU, dIoU and cIoU, we now have sIoU:
https://arxiv.org/abs/2205.12740
(How many letters will IoU get? 4/26 currently 😁😁)
I think SSIM is also a good candidate here, and there is an existing issue in the PyTorch repo: https://github.com/pytorch/pytorch/issues/6934
Any thoughts? @datumbox @oke-aditya
Yes, but we don't have any task or usage for SSIM in torchvision.
Rather than opening a new issue about focal loss, I figured it might be simplest to comment here. Is there a timeline for reorganizing sigmoid_focal_loss and/or upgrading it to multiclass? It would also be useful to have it as a subclass of _WeightedLoss from torch.nn.modules. Thanks!