Optim-wip: Composable loss improvements
This PR adds a few simple improvements to the `CompositeLoss` class and its features.
- Added support for the `__pos__` and `__abs__` unary operators to `CompositeLoss`. These appear to be the only other basic operators that make sense to add support for.
- Added `operator.floordiv` support. The current operator is being deprecated, but the operator symbol itself is not likely to be removed; instead, its functionality will be changed: https://github.com/pytorch/pytorch/issues/58743
- Made `CompositeLoss`'s reduction operation a global variable that users can change. This should improve the generality of the optim module, and it makes it possible to disable this aspect of `CompositeLoss`.
- Added composable `torch.mean` and `torch.sum` reduction operations to `CompositeLoss`. These are common operations, so there are likely use cases that can benefit from them. Example usage: `loss_fn.mean()` & `loss_fn.sum()` (see the sketch after this list).
- Added the `custom_composable_op` function, which should allow for the composability of many Python & PyTorch operations, as well as custom user operations. This should let users cover any operations that aren't supported by default in Captum.
- Added the `rmodule_op` function for handling the 3 "r" versions of math operations. This helps simplify the code.
- Added tests for the changes listed above.
- `2.0 ** loss_obj`, `2.0 / loss_obj`, & `2.0 // loss_obj` all work without the reduction op, so I've removed it for those cases.
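To make the operator overloads above easier to picture, here is a minimal, self-contained sketch of the composable-loss pattern. The names used here (`SimpleLoss`, `apply`) are hypothetical and only illustrate the idea; they do not mirror Captum's actual `CompositeLoss` implementation.

```python
import operator
import torch


class SimpleLoss:
    """Wraps a callable returning a tensor and supports composition."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    # Binary operators compose two losses (or a loss and a number).
    def _binary(self, other, op):
        if isinstance(other, SimpleLoss):
            return SimpleLoss(lambda x: op(self.fn(x), other.fn(x)))
        return SimpleLoss(lambda x: op(self.fn(x), other))

    def __add__(self, other):
        return self._binary(other, operator.add)

    def __mul__(self, other):
        return self._binary(other, operator.mul)

    # "r" variants handle `number <op> loss`, e.g. 2.0 / loss_obj.
    def _rbinary(self, other, op):
        return SimpleLoss(lambda x: op(other, self.fn(x)))

    def __rpow__(self, other):
        return self._rbinary(other, operator.pow)

    def __rtruediv__(self, other):
        return self._rbinary(other, operator.truediv)

    # Unary operators mentioned in the list above.
    def __pos__(self):
        return SimpleLoss(lambda x: +self.fn(x))

    def __abs__(self):
        return SimpleLoss(lambda x: torch.abs(self.fn(x)))

    # Composable reduction ops, analogous to loss_fn.mean() / loss_fn.sum().
    def mean(self):
        return SimpleLoss(lambda x: torch.mean(self.fn(x)))

    def sum(self):
        return SimpleLoss(lambda x: torch.sum(self.fn(x)))

    # Analogue of a custom-op hook: wrap any callable around the loss output.
    def apply(self, op, *args, **kwargs):
        return SimpleLoss(lambda x: op(self.fn(x), *args, **kwargs))


if __name__ == "__main__":
    loss_a = SimpleLoss(lambda x: x)          # identity "loss" for demonstration
    loss_b = SimpleLoss(lambda x: x * 2.0)

    combined = abs(loss_a + loss_b).mean() + 2.0 / loss_b.sum()
    clamped = loss_a.apply(torch.clamp, min=0.0).sum()

    x = torch.randn(4)
    print(combined(x), clamped(x))
```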
Hi, thank you for making this, but I may be missing some context/history here. Why do we need the "composable loss" in Captum?
PyTorch already provides a convention for losses: a function/callable module wrapping some tensor operations. For example, if I need a loss made of others:
```python
def new_loss(output_tensor):
    return nn.SomeLoss()(output_tensor) + some_other_loss(output_tensor) + torch.linalg.norm(output_tensor)
```
PyTorch tensors already support these basic arithmetic operations for modifying & combining loss tensors. Why are we composing functions/callable modules with arithmetic operators, instead of composing tensors? PyTorch supports more operations than we do, and the "composable loss" cannot be composed with any existing PyTorch losses.
I think our optim losses can work the same way without "composable loss" and be even more flexible. For example:
```python
deepdream = DeepDream(target)
layeractivation = LayerActivation(target)

def new_loss(targets_to_values: ModuleOutputMapping):
    loss = deepdream(targets_to_values) + layeractivation(targets_to_values)
    # can also use a pytorch loss
    return loss + nn.SomeLoss()(targets_to_values[target])
```
@aobo-y Originally, Captum's loss functions were set up similarly to the simple class-like functions that Lucid uses. Upon review, we then changed the losses to use classes instead.
Ludwig (one of the main Lucid developers) designed the initial optim module to use a Lucid-like composable loss system. One of the main benefits of the composable loss system is ease of use and built-in target tracking: the list of targets has to be created regardless of whether or not we use composable losses, and doing it this way means the user doesn't have to repeat the loss targets in multiple locations (a rough sketch of that idea is below). It also allows for easy handling of things like batch-specific targeting.
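To make the target-tracking point concrete, here is a rough sketch of how composing loss objects can merge their target lists so the caller never repeats them. The names here (`TrackedLoss`, the `targets` attribute) are hypothetical and not Captum's actual API.

```python
import torch.nn as nn


class TrackedLoss:
    def __init__(self, fn, targets):
        self.fn = fn                  # maps {module: output tensor} -> scalar
        self.targets = list(targets)  # modules whose outputs this loss needs

    def __call__(self, targets_to_values):
        return self.fn(targets_to_values)

    def __add__(self, other):
        # Composition merges the target lists, so the caller never has to
        # re-declare which layers need forward hooks.
        return TrackedLoss(
            lambda tv: self.fn(tv) + other.fn(tv),
            self.targets + other.targets,
        )


if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
    layer_a, layer_b = model[0], model[2]

    loss_a = TrackedLoss(lambda tv: tv[layer_a].mean(), [layer_a])
    loss_b = TrackedLoss(lambda tv: tv[layer_b].mean(), [layer_b])
    loss_fn = loss_a + loss_b

    # The rendering loop can derive the hook targets from the loss itself.
    print(loss_fn.targets)  # [model[0], model[2]]
```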
This PR can be skipped for now.