
Using torch.bfloat16 to prevent overflow instead of default fp16 in AMP

Open rajeevgl01 opened this issue 2 years ago • 0 comments

Use torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16 (5 versus 8), so its largest finite value is only 65504. Activations or gradients that exceed this overflow to inf, which causes NaN losses and NaN grad norms during AMP training. This seems to be a common issue when training the Swin Transformer.
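A minimal sketch of what a bfloat16 AMP step could look like, assuming a generic classification model and PyTorch's `torch.autocast` context manager. The helper name `train_step` and its parameters are hypothetical and are not taken from the Swin Transformer training script.

```python
import torch
import torch.nn as nn


def train_step(model, inputs, targets, optimizer, device_type="cuda"):
    """Run one optimization step under bfloat16 autocast (hypothetical helper)."""
    optimizer.zero_grad(set_to_none=True)
    # Autocast runs eligible ops (e.g. matmuls) in bfloat16 instead of float16.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        outputs = model(inputs)
    # Compute the loss in FP32; the cast is differentiable.
    loss = nn.functional.cross_entropy(outputs.float(), targets)
    # Assumption: no GradScaler here, since bfloat16 covers roughly FP32's range.
    loss.backward()
    optimizer.step()
    return loss.item()
```

Loss scaling exists mainly to keep small float16 gradients from underflowing, so this sketch leaves it out. Whether that is safe for a given run is worth checking by monitoring the loss and grad norm.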

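BFloat16 has the same number of exponent bits as FP32 (8) but fewer mantissa (precision) bits: 7 versus 23. If we want higher precision than bfloat16, TensorFloat-32 (TF32) can be used instead. Note that TF32 values are still stored as 32-bit floats, so TF32 speeds up matrix multiplications and convolutions rather than reducing GPU memory.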
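To make the range versus precision trade-off concrete, here is a small sketch in plain Python that derives the largest finite value and the machine epsilon from each format's bit layout. The layouts are the standard ones; the function names are illustrative only.

```python
# (exponent bits, explicit mantissa bits) for each format.
FORMATS = {
    "fp16": (5, 10),
    "bf16": (8, 7),
    "tf32": (8, 10),
    "fp32": (8, 23),
}


def max_finite(exp_bits, man_bits):
    """Largest finite value of an IEEE-style format with the given layout."""
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2.0 ** -man_bits) * 2.0 ** bias


def machine_epsilon(man_bits):
    """Gap between 1.0 and the next representable value."""
    return 2.0 ** -man_bits


for name, (e, m) in FORMATS.items():
    print(f"{name}: max = {max_finite(e, m):.4g}, eps = {machine_epsilon(m):.3g}")
```

The maximum is set almost entirely by the exponent width, which is why fp16 tops out at 65504 while the three 8-bit-exponent formats all reach about 3.4e38.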

TF32 has fewer mantissa bits than FP32 (10 versus 23), but three more exponent bits than FP16 (8 versus 5). However, TF32 can only be used on NVIDIA Ampere GPUs or newer.
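A possible way to opt in to TF32 in PyTorch, sketched under the assumption that the `allow_tf32` backend flags are available in the installed version. The helper names `supports_tf32` and `enable_tf32` are hypothetical, and the compute capability check (major version 8 or higher for Ampere and newer) is a heuristic.

```python
import torch


def supports_tf32():
    """Heuristic: TF32 needs an NVIDIA GPU with compute capability >= 8.0."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8


def enable_tf32(enabled=True):
    """Allow (or disallow) TF32 for cuBLAS matmuls and cuDNN convolutions."""
    torch.backends.cuda.matmul.allow_tf32 = enabled
    torch.backends.cudnn.allow_tf32 = enabled
```

On hardware without TF32 support the flags have no effect, so calling `enable_tf32()` only when `supports_tf32()` returns True mainly makes the intent explicit.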

rajeevgl01 · Jan 12 '24 20:01