
low precision training upcoming feature tracker

Open · vkuzo opened this issue on Jul 30 '24

This is a running list of planned features for low precision training. As features are completed, we plan to delete them from this list to keep things simple.

float8

performance

  • [in progress] optimize torch.compile performance for float8 tensorwise scaling/casting kernels (see the sketch after this list for what these kernels compute)
    • [fixed behind a flag, off by default] https://github.com/pytorch/pytorch/issues/130015
    • [planned] https://github.com/pytorch/pytorch/issues/128063
  • [in progress] ensure that float8 rowwise scaling is performant with TP and async TP https://github.com/pytorch/pytorch/issues/149990
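
For context on the tensorwise item above, here is a minimal sketch of what the scaling/casting path computes. It is plain PyTorch rather than the actual torchao kernels, assumes a float8-capable GPU (e.g. H100), and uses torch._scaled_mm, whose exact signature has varied across PyTorch releases.

```python
import torch

# Tensorwise float8 scaling + scaled matmul (illustrative sketch, not the
# torchao implementation).
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def cast_tensorwise(x: torch.Tensor):
    # One scale for the whole tensor: map the max abs value onto the e4m3 range.
    scale = E4M3_MAX / x.abs().max().clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale

a = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)  # activation
b = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)  # weight

a_fp8, a_scale = cast_tensorwise(a)
b_fp8, b_scale = cast_tensorwise(b)

# torch._scaled_mm expects mat2 in column-major layout and takes the dequant
# (reciprocal) scales as float32 tensors; the output is ~ a @ b.t().
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=a_scale.reciprocal().float(),
    scale_b=b_scale.reciprocal().float(),
    out_dtype=torch.bfloat16,
)
```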

distributed

  • [planned] verify integration with pipeline parallelism (PP)

new features

  • [2025-Q2] float8 grouped gemm support
  • [2025-Q2] better story for float8 training -> float8 inference
  • productionize the no-compile version of float8 training (https://github.com/pytorch/ao/tree/main/torchao/prototype/float8nocompile, priority TBD)
  • [2025-Q2] weight gradient accumulation in float32 (see the numerical sketch after this list)
  • float8 SDPA (priority TBD)
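
On the float32 weight gradient accumulation item above: the motivation is numerical, since summing many small per-microbatch gradients in bfloat16 drops low-order bits that a float32 accumulator keeps. A small illustrative sketch (plain PyTorch; in float8 training the per-microbatch gradients would come out of a float8 wgrad gemm):

```python
import torch

# Compare accumulating per-microbatch weight gradients in bfloat16 vs float32.
torch.manual_seed(0)
grads = [torch.randn(512, 512) * 1e-3 for _ in range(64)]  # per-microbatch wgrads

acc_bf16 = torch.zeros(512, 512, dtype=torch.bfloat16)
acc_fp32 = torch.zeros(512, 512, dtype=torch.float32)
for g in grads:
    acc_bf16 += g.to(torch.bfloat16)  # low-precision accumulator
    acc_fp32 += g                     # float32 accumulator

ref = torch.stack(grads).sum(dim=0).double()
print("bf16 accumulation max error:", (acc_bf16.double() - ref).abs().max().item())
print("fp32 accumulation max error:", (acc_fp32.double() - ref).abs().max().item())
```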

ecosystem

  • [in progress] add torchtune integration (https://github.com/pytorch/torchtune/pull/2404)

other

  • [2025-Q2] expose float8 training via the quantize_ API (the current entry point is sketched after this list)
  • [2025-Q2] migrate torchao.float8 code to torchao.quantization for better unification with the rest of torchao, in a BC-preserving way
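
For reference, the entry point that would be unified under quantize_ currently looks roughly like the following. This is a sketch based on the torchao.float8 API at the time of writing (convert_to_float8_training and Float8LinearConfig); check the torchao README for the up-to-date names.

```python
import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
).to(torch.bfloat16).cuda()

# Swap nn.Linear modules for float8 training variants; the config controls
# scaling granularity and distributed options such as the float8 all-gather.
convert_to_float8_training(model, config=Float8LinearConfig())

# torch.compile generates the float8 scaling/casting kernels discussed in the
# performance section above.
model = torch.compile(model)
```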

MX

pytorch/pytorch

  • [in progress] fp4_x2 dtype
  • [in progress] torch._scaled_mm for nvfp4, wrapping cuBLAS
  • [in progress] inductor performance work for mx block scaling fusion into surrounding ops (see the block scaling sketch after this list): https://github.com/pytorch/pytorch/issues/149982
  • [2025-Q1] PT2 integration for e8m0 and fp4_x2
    • https://github.com/pytorch/pytorch/issues/147873
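
To make the block scaling terminology concrete, here is a plain-PyTorch sketch of MX-style quantization with one power-of-two (e8m0-style) scale per 32-element block. It uses a simplified scale-selection rule and is not the exact OCP MX spec or the torchao/inductor implementation.

```python
import torch

BLOCK = 32
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def to_mx_e4m3(x: torch.Tensor):
    # One power-of-two scale per block of 32 elements, elements in fp8 e4m3.
    blocks = x.reshape(-1, BLOCK).float()
    amax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=2.0**-126)
    # Smallest power of two such that the block max fits in the e4m3 range
    # (simplified; the OCP MX spec's exponent rule differs slightly).
    scale = torch.exp2(torch.ceil(torch.log2(amax / E4M3_MAX)))
    elems = (blocks / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return elems.reshape(x.shape), scale  # scale would be stored as e8m0

def from_mx(elems: torch.Tensor, scale: torch.Tensor):
    return (elems.float().reshape(-1, BLOCK) * scale).reshape(elems.shape)

x = torch.randn(4, 128)
elems, scale = to_mx_e4m3(x)
print("max abs reconstruction error:", (x - from_mx(elems, scale)).abs().max().item())
```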

pytorch/torchao

  • [in progress] performance: https://github.com/pytorch/ao/issues/1768
  • [in progress] expose in quantize_ API

pytorch/torchtitan

  • [in progress] integrate mx training: https://github.com/pytorch/torchtitan/pull/1015

vkuzo · Jul 30 '24

Does SDPA already support FP8, like flash-attention (https://github.com/Dao-AILab/flash-attention/blob/main/hopper/benchmark_flash_attention_fp8.py) does?

airMeng · Jan 17 '25

> Does SDPA already support FP8, like flash-attention (https://github.com/Dao-AILab/flash-attention/blob/main/hopper/benchmark_flash_attention_fp8.py) does?

We would be interested in accepting a contribution for this! To me, the priority is not clear, as we have not seen convincing evidence that SDPA + float8 can achieve good accuracy numerically.

vkuzo · Feb 19 '25

@vkuzo, is there a plan for MXFP8 all-gather? If so, when is the feature expected to be enabled?

avizon-aws · Aug 15 '25

> @vkuzo, is there a plan for MXFP8 all-gather? If so, when is the feature expected to be enabled?

We have considered this for float8 and MX, but it has not made the priority cut yet. Could you share some context about your use case, and whether you need this for FSDP or TP?

vkuzo · Aug 15 '25