Swin-Transformer
Using torch.bfloat16 instead of the default float16 in AMP to prevent overflow
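We use torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16 (5 vs. 8), so its largest finite value is 65504, while bfloat16 reaches roughly 3.39e38. Activations or gradients that exceed the float16 range become inf, which then turns into NaN loss and NaN gradient norms during AMP training. This seems to be a common issue when training the Swin Transformer.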
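A minimal sketch of one training step with bfloat16 autocast, assuming a recent PyTorch (1.10 or later, where `torch.autocast` takes a `dtype` argument). The toy `nn.Linear` model, the `train_step` helper, the random batch and `max_norm=5.0` are hypothetical placeholders for the real Swin model, data loader and clipping setting. Because bfloat16 shares FP32's exponent range, this sketch skips the `GradScaler` that float16 AMP usually needs.

```python
import torch
import torch.nn as nn

# Use the GPU only if it supports bfloat16; otherwise fall back to CPU autocast.
device = "cuda" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "cpu"

# Hypothetical stand-ins for the real Swin model, optimizer and loss.
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

def train_step(x, y):
    optimizer.zero_grad(set_to_none=True)
    # Forward pass under autocast with bfloat16 instead of the CUDA default float16.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits = model(x)
        loss = loss_fn(logits, y)
    # No GradScaler: bfloat16 keeps FP32's exponent range, so loss scaling is
    # typically unnecessary.
    loss.backward()
    # max_norm=5.0 is an illustrative value, not taken from the original text.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    optimizer.step()
    return loss.item(), grad_norm.item(), logits.dtype

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)
loss, grad_norm, logits_dtype = train_step(x, y)
print(loss, grad_norm, logits_dtype)
```

Parameters and gradients stay in FP32 here; only the autocast-eligible ops inside the `with` block run in bfloat16.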
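BFloat16 has the same number of exponent bits as FP32 (8), and therefore roughly the same dynamic range, but far fewer precision (mantissa) bits: 7 instead of 23. If higher precision matters more than memory savings, TensorFloat-32 (TF32) can be used instead: tensors stay in FP32 in memory, so GPU memory use is not reduced, but matrix multiplications and convolutions run on tensor cores at TF32 precision.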
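The range-versus-precision trade-off can be checked without a GPU. This sketch uses only the Python standard library: `struct`'s `'e'` format gives IEEE float16, while bfloat16 is emulated (an assumption of this sketch, not a PyTorch API) by rounding a float32 to its top 16 bits. The helper names `to_fp16` and `to_bf16` are hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip through IEEE float16; values beyond its range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_bf16(x: float) -> float:
    """Emulate bfloat16: keep sign, 8 exponent bits and 7 mantissa bits of the
    float32 encoding, rounding to nearest even. NaN inputs are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Range: 70000 overflows float16 but fits easily in bfloat16.
print(to_fp16(70000.0), to_bf16(70000.0))  # inf 70144.0
# Precision: float16 resolves 1.001 better than bfloat16 does.
print(to_fp16(1.001), to_bf16(1.001))      # 1.0009765625 1.0
```

The second line shows the cost of bfloat16's wider range: with only 7 mantissa bits, the spacing between values near 1.0 is 2^-7, versus 2^-10 for float16.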
TF32 has fewer precision bits than FP32 (10 vs. 23 mantissa bits) but 3 more exponent bits than FP16 (8 vs. 5), so it keeps FP32's range with FP16's precision. However, TF32 is only available on NVIDIA Ampere GPUs or newer.
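A minimal sketch of turning TF32 on in PyTorch only when the GPU can use it. The `enable_tf32_if_supported` helper is hypothetical; the two `allow_tf32` flags are PyTorch's switches for FP32 matmuls and cuDNN convolutions, and compute capability 8.0 corresponds to Ampere.

```python
import torch

def enable_tf32_if_supported() -> bool:
    """Hypothetical helper: enable TF32 only on Ampere (sm_80) or newer GPUs."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        return False
    # Tensors remain float32 in memory; only tensor-core math uses TF32 precision.
    torch.backends.cuda.matmul.allow_tf32 = True  # FP32 matrix multiplications
    torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions
    return True

enabled = enable_tf32_if_supported()
print("TF32 enabled:", enabled)
```

Newer PyTorch releases also offer `torch.set_float32_matmul_precision("high")` as a higher-level switch for the matmul part.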