low precision training upcoming feature tracker
This is a running list of planned features for low precision training. As features are completed, we plan to remove them from this list to keep it simple.
float8
performance
- [in progress] optimize torch.compile performance for float8 tensorwise scaling/casting kernels
- [fixed behind a flag, off by default] https://github.com/pytorch/pytorch/issues/130015
- [planned] https://github.com/pytorch/pytorch/issues/128063
- [in progress] ensure that float8 rowwise scaling is performant with TP and async TP https://github.com/pytorch/pytorch/issues/149990
distributed
- [planned] verify integration with PP
new features
- [2025-Q2] float8 grouped gemm support
- [2025-Q2] better story for float8 training -> float8 inference
- productionize no-compile version of float8 training (https://github.com/pytorch/ao/tree/main/torchao/prototype/float8nocompile, priority TBD)
- [2025-Q2] weight gradient accumulation in float32 (see the conceptual sketch after this list)
- float8 SDPA (priority TBD)
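As background for the float32 weight-gradient accumulation item above, here is a minimal conceptual sketch of the idea (illustrative only, not torchao's implementation): per-microbatch weight gradients produced in a lower precision are summed into a float32 buffer so that small contributions are not lost. The tensors and the loop below are placeholders.

```python
import torch

# Conceptual sketch only (not torchao's implementation): per-microbatch weight
# gradients produced in a lower precision are accumulated into a float32 buffer.
torch.manual_seed(0)
param = torch.nn.Parameter(torch.randn(1024, 1024, dtype=torch.bfloat16))
grad_acc_fp32 = torch.zeros(1024, 1024, dtype=torch.float32)

num_microbatches = 8
for _ in range(num_microbatches):
    # stand-in for the per-microbatch weight gradient from a float8/bf16 backward gemm
    microbatch_grad = torch.randn(1024, 1024, dtype=torch.bfloat16) * 1e-3
    grad_acc_fp32 += microbatch_grad.float()

# hand the accumulated gradient to the optimizer (here cast back to the param dtype)
param.grad = grad_acc_fp32.to(param.dtype)
```

The point is that repeatedly adding small bfloat16 values into a bfloat16 accumulator can round the additions away entirely, while a float32 buffer preserves them.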
ecosystem
- [in progress] add torchtune integration (https://github.com/pytorch/torchtune/pull/2404)
other
- [2025-Q2] expose float8 training via the quantize_ API (see the sketch after this list)
- [2025-Q2] migrate torchao.float8 code to torchao.quantization for better unification with the rest of torchao, in a BC-preserving way
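For readers unfamiliar with the APIs referenced in the last two items, the sketch below shows today's module-swap entry point for float8 training and, in comments, a purely hypothetical shape of a quantize_-based entry point. It assumes a recent torchao build and float8-capable hardware; the config class name in the comment does not exist in torchao and is only there to illustrate the planned direction.

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# current entry point: swap nn.Linear modules for float8 training wrappers
model = nn.Sequential(nn.Linear(2048, 4096), nn.ReLU(), nn.Linear(4096, 2048))
convert_to_float8_training(model)

# hypothetical shape of the planned quantize_-based entry point (the config
# class below is illustrative only and is not an existing torchao API):
#
#   from torchao.quantization import quantize_
#   quantize_(model, Float8TrainingConfig())
```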
MX
pytorch/pytorch
- [in progress] fp4_x2 dtype
- [in progress] torch._scaled_mm for nvfp4, wrapping cuBLAS
- [in progress] inductor performance work for mx block scaling fusion into surrounding ops: https://github.com/pytorch/pytorch/issues/149982 (see the conceptual sketch after this list)
- [2025-Q1] PT2 integration for e8m0 and fp4_x2: https://github.com/pytorch/pytorch/issues/147873
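To make the mx block scaling item above more concrete, here is a conceptual sketch of MX-style block quantization following the OCP MX spec (blocks of 32 contiguous elements share one power-of-two scale, and elements are stored as float8 e4m3 for MXFP8). The helper name is made up and this is not the torchao or inductor kernel; the fusion work referenced above is about folding exactly this kind of scaling/casting math into the surrounding ops.

```python
import torch

def mx_block_quant_sketch(x: torch.Tensor, block_size: int = 32):
    """Conceptual sketch of MX-style block scaling (per the OCP MX spec),
    not the torchao kernels. Assumes x.numel() is divisible by block_size."""
    blocks = x.reshape(-1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    # shared power-of-two scale chosen so the block max lands in e4m3 range
    # (e4m3 max normal ~= 448, i.e. emax = 8); e8m0 storage of the exponent
    # is elided here for simplicity
    exponent = torch.floor(torch.log2(amax.clamp(min=2**-127))) - 8
    scale = torch.exp2(exponent)
    elems = (blocks / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return elems.reshape(x.shape), scale

x = torch.randn(4, 64)
x_mx, scales = mx_block_quant_sketch(x)
```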
pytorch/torchao
- [in progress] performance: https://github.com/pytorch/ao/issues/1768
- [in progress] expose in quantize_ API
pytorch/torchtitan
- [in progress] integrate mx training: https://github.com/pytorch/torchtitan/pull/1015
Does SDPA already support FP8, like flash-attention does (https://github.com/Dao-AILab/flash-attention/blob/main/hopper/benchmark_flash_attention_fp8.py)?
We would be interested in accepting a contribution for this! To me, the priority is not clear, as we have not seen convincing evidence that SDPA + float8 can achieve good numerical accuracy.
@vkuzo, is there a plan for MXFP8 all-gather? If so, when is the feature expected to be enabled?
We have considered this for float8 and mx, but it has not made the priority cut yet. Could you share some context about your use case, and whether you need this for FSDP or TP?
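As an aside for comparison: tensorwise float8 training in torchao does expose a flag to perform the FSDP2 weight all-gather in float8 (assuming a recent torchao release); an MX/MXFP8 equivalent of this flag does not exist yet, which is what the question above is about. A minimal sketch of the float8 version:

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# enable float8 all-gather for FSDP2 with tensorwise float8 training
# (flag per current torchao; there is no MX analogue of this flag yet)
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)
# afterwards, wrap the model with FSDP2 (fully_shard) as usual
```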