
Make loss functions and regularizers classes that inherit from `torch.nn.Module`

Open iancze opened this issue 3 years ago • 3 comments

Currently our loss functions are coded as straightforward functions working on torch inputs. Some loss functions have additional parameters that are set at initialization, for example,

`def entropy(cube, prior_intensity): ...`

where `prior_intensity` is a reference cube used in the evaluation of the actual target, `cube`.

This works fine, but can get cumbersome, especially when we are interfacing with a bunch of loss functions all at once (as in a cross-validation loop).

Can we take a page from the way PyTorch designs its loss functions and make most, if not all, loss functions classes that inherit from `torch.nn.Module`? This would create objects that can easily be instantiated with default parameter values and that are all called the same way. For example, see MSE Loss.
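As a rough sketch of what this proposal could look like for the `entropy` example, the block below wraps a functional form in a module that stores the reference cube once. The class name `Entropy`, the use of `register_buffer`, and the body of `entropy` itself are assumptions made for illustration and are not taken from the MPoL source.

```python
import torch
import torch.nn as nn


def entropy(cube: torch.Tensor, prior_intensity: torch.Tensor) -> torch.Tensor:
    # Stand-in functional form (assumed for illustration): relative entropy of
    # the cube with respect to the reference cube. Requires positive pixels.
    return torch.sum(cube * torch.log(cube / prior_intensity))


class Entropy(nn.Module):
    """Hypothetical module wrapper that stores the reference cube once."""

    def __init__(self, prior_intensity: torch.Tensor) -> None:
        super().__init__()
        # A buffer moves with the module on .to(device) but is not trained.
        self.register_buffer("prior_intensity", prior_intensity)

    def forward(self, cube: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form so both stay in sync.
        return entropy(cube, self.prior_intensity)


prior = torch.full((4, 4), 0.5)
loss_fn = Entropy(prior)   # extra parameter is set once, at instantiation
cube = torch.ones(4, 4)
loss = loss_fn(cube)       # every loss object is then called as loss_fn(cube)
```

With this shape, a loop over several loss objects only needs to pass the cube, regardless of which extra parameters each loss was built with.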

This may have additional benefits (with the `reduction` argument, say) if we think about batching and applications to multiple GPUs.

Does it additionally make sense to include the lambda terms as parameters of the loss object, too? @kadri-nizam do you have any thoughts from experience w/ your VAE architecture?

iancze, Feb 03 '23 20:02

I think making versions of the loss functions as torch modules is a great idea. I'd still keep the functional definitions separate and import them when defining the modules, as that is more flexible (for developing and testing). I believe this is how PyTorch implements it: the logic for the losses lives in `torch.nn.functional`, which the `nn.Module` versions then call.
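The PyTorch pattern mentioned here can be seen with MSE loss, where the functional form and the module form give the same value. This is only a small illustration with made-up tensors; it also shows the `reduction` option that both forms accept.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

prediction = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])

# Functional form: a plain function call, convenient for tests.
functional_value = F.mse_loss(prediction, target)

# Module form: configured once, then called like any other loss object.
module_value = nn.MSELoss()(prediction, target)

# The reduction is fixed at instantiation for the module form.
summed_value = nn.MSELoss(reduction="sum")(prediction, target)
```

Here `functional_value` and `module_value` are equal, and `summed_value` is the unaveraged sum of squared differences.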

> Does it additionally make sense to include the lambda terms as parameters of the loss object, too?

The lambda parameter doesn't change throughout the optimization run, right? If so, then I'd include it as an argument at instantiation.

kadri-nizam, Mar 05 '23 01:03

Yes, the 'lambda' parameter remains fixed during an optimization run. In a cross-validation loop you'd want to try several different lambda values, so in that situation I guess you would need to re-instantiate the loss functions.
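A minimal sketch of that re-instantiation, assuming a toy regularizer: the `Sparsity` class and the `lambdas` grid below are hypothetical, and a real cross-validation loop would train and score a model where this sketch only evaluates the loss on a fixed image.

```python
import torch
import torch.nn as nn


class Sparsity(nn.Module):
    """Hypothetical L1 regularizer whose strength is fixed at instantiation."""

    def __init__(self, lam: float) -> None:
        super().__init__()
        self.lam = lam

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.lam * image.abs().sum()


image = torch.ones(3, 3)
lambdas = [1e-3, 1e-2, 1e-1]  # hypothetical grid of strengths to try

results = {}
for lam in lambdas:
    # A fresh loss object per trial; lambda stays fixed within each run.
    regularizer = Sparsity(lam)
    # Placeholder for a full training and validation run.
    results[lam] = regularizer(image).item()
```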

iancze, Mar 05 '23 18:03

Hi Ian,

Thank you for a productive discussion today! Here's an example of how I implemented the loss functions in my fork:

import torch
import torch.nn as nn

class TV(nn.Module):
    def __init__(self, λ: float, /) -> None:
        super().__init__()
        # Regularization strength, fixed for the lifetime of the instance
        self.λ = λ

    def __repr__(self) -> str:
        return f"TV(λ={self.λ})"

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.λ * TV.functional(image)

    @staticmethod
    def functional(image: torch.Tensor) -> torch.Tensor:
        # Squared differences between adjacent rows and adjacent columns of a
        # 2D image, each cropped to a common (H-1, W-1) shape
        row_diff = torch.diff(image[:, :-1], dim=0).pow(2)
        column_diff = torch.diff(image[:-1, :], dim=1).pow(2)
        return torch.add(row_diff, column_diff).sqrt().sum()

The purpose of having the functional static method is to allow for easier testing -- just call TV.functional directly, without needing to instantiate the module or pick a λ first.
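To illustrate the testing convenience, the sketch below copies the body of `functional` into a free function (named `tv_functional` here purely for the example) and checks it on two small 2D images. The test images are made up.

```python
import torch


def tv_functional(image: torch.Tensor) -> torch.Tensor:
    # Same body as the static method in the comment above.
    row_diff = torch.diff(image[:, :-1], dim=0).pow(2)
    column_diff = torch.diff(image[:-1, :], dim=1).pow(2)
    return torch.add(row_diff, column_diff).sqrt().sum()


# A constant image has no variation, so the result should be zero.
flat_tv = tv_functional(torch.ones(4, 4))

# For a 2x2 image there is a single term: sqrt(2**2 + 1**2) = sqrt(5).
ramp = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
ramp_tv = tv_functional(ramp)
```

No λ and no module instance are involved, so the numerical core can be checked in isolation.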

I defined an abstract base class in my fork to specify requirements that a loss module in the repo must meet, but this is optional.
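One possible shape for such a base class is sketched below. It is a guess at the idea rather than the code in the fork: the names `LossModule` and `SquaredSum` and the specific requirements (a static `functional` plus a stored `lam`) are assumptions.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class LossModule(nn.Module, ABC):
    """Hypothetical base class: subclasses must provide `functional`."""

    def __init__(self, lam: float) -> None:
        super().__init__()
        self.lam = lam

    @staticmethod
    @abstractmethod
    def functional(image: torch.Tensor) -> torch.Tensor:
        """Unweighted loss value for a single image."""

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Shared weighting logic lives in one place.
        return self.lam * self.functional(image)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(lam={self.lam})"


class SquaredSum(LossModule):
    """Toy subclass used only to exercise the base class."""

    @staticmethod
    def functional(image: torch.Tensor) -> torch.Tensor:
        return image.pow(2).sum()


value = SquaredSum(0.5)(torch.ones(2, 2))
```

Instantiating `LossModule` directly raises a `TypeError`, which is how the base class enforces that every loss supplies its own `functional`.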

kadri-nizam, Jan 09 '24 17:01