graphnet
BinaryCrossEntropyLoss numerical instabilities
This issue came up in the GraphNet Slack:
When training with mixed precision using the BinaryClassificationTask together with the BinaryCrossEntropyLoss, you get the following error:
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.
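The error message recommends fusing the sigmoid into the loss. The pure-Python sketch below compares the two computations for a single logit x and label y to show why. The fused version uses the standard stable identity max(x, 0) - x*y + log(1 + exp(-|x|)). It illustrates the idea behind binary_cross_entropy_with_logits and is not PyTorch's actual kernel. The helper names are hypothetical, and the -100 clamp on the logs mimics the clamp documented for torch.nn.BCELoss.

```python
import math


def clamped_log(v: float) -> float:
    """Natural log clamped to >= -100, mimicking the clamp documented for torch.nn.BCELoss."""
    if v <= 0.0:
        return -100.0
    return max(math.log(v), -100.0)


def bce_after_sigmoid(logit: float, target: float) -> float:
    """Hypothetical helper: naive two-step BCE (sigmoid first, then logs)."""
    p = 1.0 / (1.0 + math.exp(-logit))  # rounds to exactly 1.0 for large positive logits
    return -(target * clamped_log(p) + (1.0 - target) * clamped_log(1.0 - p))


def bce_with_logits(logit: float, target: float) -> float:
    """Hypothetical helper: fused BCE on the raw logit via the stable identity
    max(x, 0) - x * y + log(1 + exp(-|x|)); exp() only ever sees a non-positive argument."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


for x, y in [(0.5, 1.0), (-2.0, 0.0), (40.0, 0.0)]:
    print(f"x={x:5.1f} y={y}: sigmoid+BCE={bce_after_sigmoid(x, y):.6f}  "
          f"with_logits={bce_with_logits(x, y):.6f}")
```

For moderate logits, the two versions agree. For x = 40 and y = 0, the sigmoid rounds to exactly 1.0 in float64, so the two-step loss hits the clamp and returns 100.0. The fused form returns 40.0, which is the true loss up to a term of order 1e-18. In float16, as used under autocast, the sigmoid saturates at much smaller logits, which is why PyTorch refuses to autocast the two-step version.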
The remedy is to use the BinaryClassificationTaskLogits instead. However, GraphNet currently has no LossFunction that uses torch.nn.functional.binary_cross_entropy_with_logits, so there is nothing adequate to pair it with.
This is a simple fix, which I have already implemented (look for the corresponding PR). I am posting it here for reference.
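Below is a minimal sketch of such a loss, written as a plain torch.nn.Module so that it runs without GraphNet installed. The class name BinaryCrossEntropyWithLogitsLoss is hypothetical. Inside GraphNet, it would presumably subclass the library's LossFunction base class instead, and that wiring is not shown here. The sketch assumes the task emits raw, un-squashed logits and that targets are 0/1 labels.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


class BinaryCrossEntropyWithLogitsLoss(torch.nn.Module):
    """Hypothetical loss: binary cross entropy computed directly on logits.

    `prediction` holds raw logits (no sigmoid applied); `target` holds 0/1 labels.
    """

    def forward(self, prediction: Tensor, target: Tensor) -> Tensor:
        # Cast to float32 and flatten so (N, 1) predictions line up with (N,) targets.
        elements = F.binary_cross_entropy_with_logits(
            prediction.float().flatten(), target.float().flatten(), reduction="none"
        )
        # Per-event losses averaged into a scalar.
        return elements.mean()


# Usage: compare against the two-step sigmoid + binary_cross_entropy on safe inputs.
torch.manual_seed(0)
logits = torch.randn(8, 1)
labels = torch.randint(0, 2, (8,))
loss = BinaryCrossEntropyWithLogitsLoss()(logits, labels)
reference = F.binary_cross_entropy(torch.sigmoid(logits).flatten(), labels.float())
print(loss.item(), reference.item())
```

Keeping the sigmoid out of the model output and inside the loss is what makes the pairing with BinaryClassificationTaskLogits autocast-safe. Probabilities can still be obtained at inference time by applying torch.sigmoid to the logits.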