Keep Attention Softmax FP32 during FP16/ZeRO Training
Hi all,
Per the request of @ver217 in this discussion, I am opening an issue with the same title, comments, and question.
Recent findings from GLM-130B and the researchers at Tsinghua show that keeping the attention softmax in fp32 while training with fp16 and ZeRO leads to much greater stability at scale. Quoting their write-up:
> **Attention Computation: FP32 Softmax**
>
> The gradient shrink is a post-hoc technique to avoid training collapse. Essentially, the collapse is formed by an abnormal loss gradient, either because of noisy data or overflow and underflow in the forward computation. We observe that the attention computation is the operation most likely to overflow or underflow in large language models. CogView shows that different attention heads have very different value scales for their attention scores, and some value scales can reach +1e4 or -1e-3. Such varied value scales can lead to frequent overflows or underflows under FP16 in the softmax computation. CogView proposes Precision-Bottleneck Relaxation (PB-Relax) to mitigate the issue, which subtracts the maximum absolute value in each head's attention score matrix before taking the softmax. However, PB-Relax turns out to be slow in GLM-130B's training, probably because finding the maximum and manipulating scalars in 96 attention score matrices of size 2048 × 2048 is unfriendly to CUDA kernels. Finally, after a few weeks of arduous exploration, we found that the fastest and easiest way to avoid the problem is to use FP32 in the softmax computation. Compared to full FP16 computation, it brings hardly any speed loss but significantly improves training stability.
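To make the contrast concrete, here is a minimal PyTorch sketch of the two softmax paths (assumed fp16 `q`/`k` tensors and a `scale` factor; an illustration, not the GLM-130B kernel):

```python
import torch

def attention_probs_fp16(q, k, scale):
    # Full-FP16 path: scores and softmax both in fp16 -- prone to
    # overflow/underflow when score magnitudes vary widely across heads.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1)

def attention_probs_fp32_softmax(q, k, scale):
    # GLM-130B-style path: compute scores in fp16, but run the softmax
    # in fp32, then cast the probabilities back to the input dtype.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    probs = torch.softmax(scores.float(), dim=-1)
    return probs.to(q.dtype)
```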
Since ColossalAI manages floating-point precision during training, is there a recommended way to ensure that the softmax remains in fp32 without being automatically overridden by the engine once fp16/ZeRO is initialized? That way one could enable fp16 and ZeRO in the configuration while retaining the added numerical stability.
Thank you,
Enrico
You can use `torch.softmax(..., dtype=torch.float)` to cast inputs to fp32 as a workaround. We may design a more flexible AMP in the future.
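For example (a minimal sketch; `scores` stands in for your fp16 attention logits):

```python
import torch

def stable_attention_probs(scores: torch.Tensor) -> torch.Tensor:
    # scores: fp16 attention logits, e.g. shape (batch, heads, seq, seq).
    # dtype=torch.float makes softmax upcast its input to fp32 internally;
    # casting back keeps downstream matmuls in half precision.
    probs = torch.softmax(scores, dim=-1, dtype=torch.float)
    return probs.to(scores.dtype)
```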
Hi @ver217 ,
Thank you for confirming that a softmax initialized with a specific dtype will not be overridden by the ColossalAI engine. I wanted to make sure there was a way to keep the softmax in fp32 when the configuration is set to fp16.
Additionally, I believe the team at Tsinghua used Megatron-LM's `FusedScaleMaskSoftmax` to handle their use case, which allowed them to switch between fp16, bf16, and fp32.
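For reference, the behavior can be mimicked in plain PyTorch with a module like this (a hypothetical sketch, not the fused CUDA kernel; `softmax_in_fp32` mirrors the flag of the same name):

```python
import torch
import torch.nn as nn

class ScaleMaskSoftmax(nn.Module):
    """Scale + mask + softmax with an optional fp32 upcast."""

    def __init__(self, scale: float = 1.0, softmax_in_fp32: bool = True):
        super().__init__()
        self.scale = scale
        self.softmax_in_fp32 = softmax_in_fp32

    def forward(self, scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # scores: fp16/bf16 attention logits; mask: bool, True = masked out.
        in_dtype = scores.dtype
        if self.softmax_in_fp32:
            scores = scores.float()
        scores = scores * self.scale
        scores = scores.masked_fill(mask, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        return probs.to(in_dtype)
```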
Thank you,
Enrico
It's worth pointing out that `torch.autocast` will use float32 for softmax, so if one is using the built-in `torch.cuda.amp`, this shouldn't be an issue.
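A quick way to verify this (assuming a CUDA device; softmax is listed among the ops autocast runs in float32):

```python
import torch

x = torch.randn(4, 1024, 1024, device="cuda", dtype=torch.float16)
with torch.cuda.amp.autocast():
    probs = torch.softmax(x, dim=-1)
print(probs.dtype)  # expected: torch.float32, since softmax is on
                    # autocast's "autocast to float32" op list
```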
We have updated a lot since then. This issue is being closed due to inactivity. Thanks.