
Next steps for mxfp8 MoE training

Open • danielvegamyhre opened this issue 5 months ago • 2 comments

  • [ ] mxfp8 all2all -> stay in mxfp8 through the token shuffle -> mxfp8 grouped gemm
    • [x] initial mxfp8 all2all impl (drop-in replacement for all_to_all_single_autograd, sync required); a conceptual sketch of the idea follows this list
    • [ ] mxfp8 token shuffle (a modified version of this Triton kernel that also permutes scales into the same order as their associated tokens); see the shuffle sketch after this list
    • [ ] Extend the mxfp8 grouped gemm autograd func to also accept pre-quantized inputs; an interface sketch follows this list
  • [ ] Improve the 3d expert weight mxfp8 quantization CUDA kernel (currently at 65-70% of peak memory bandwidth; should target 85%+ like the other mxfp8 quantization kernels)
  • [ ] Investigate if we can write e8m0 scales directly to blocked format, instead of running separate conversion kernels.
  • [ ] Improve mxfp8 grouped gemm performance for small K dims (dsv3/kimi shapes). Currently we see less speedup for the small, skinny experts in dsv3/kimi than for the larger experts in llama4. This matters because dsv3/kimi base models are so popular now.
  • [ ] unify dense + moe mxfp8 training code bases
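
A conceptual sketch, not the torchao implementation, of the mxfp8 all2all idea above: quantize tokens to fp8 (e4m3) with one shared e8m0 scale per 32-element block, exchange the fp8 payload and the scales with two all_to_all_single calls, and reconstruct on the receiving rank. The helper names (mx_quantize, mx_dequantize, mxfp8_all_to_all) are hypothetical, this is forward-only (a drop-in autograd replacement also needs a backward that routes gradients with the splits reversed), and a fused version would keep the tokens quantized instead of dequantizing here.

```python
import torch
import torch.distributed as dist

BLOCK = 32          # MX block size: one shared e8m0 scale per 32 elements
E4M3_MAX = 448.0    # largest finite value representable in float8_e4m3fn


def mx_quantize(x: torch.Tensor):
    # x: [tokens, dim] with dim % BLOCK == 0.
    tokens, dim = x.shape
    blocks = x.reshape(tokens, dim // BLOCK, BLOCK).float()
    amax = blocks.abs().amax(dim=-1).clamp(min=2.0 ** -127)
    exp = torch.ceil(torch.log2(amax / E4M3_MAX))      # power-of-two scale exponent
    scale = torch.exp2(exp)
    data = (blocks / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    e8m0 = (exp + 127).clamp(0, 255).to(torch.uint8)   # biased exponent per block
    return data.reshape(tokens, dim), e8m0             # e8m0: [tokens, dim // BLOCK]


def mx_dequantize(data: torch.Tensor, e8m0: torch.Tensor) -> torch.Tensor:
    tokens, dim = data.shape
    scale = torch.exp2(e8m0.float() - 127.0).unsqueeze(-1)
    blocks = data.to(torch.float32).reshape(tokens, dim // BLOCK, BLOCK) * scale
    return blocks.reshape(tokens, dim)


def mxfp8_all_to_all(x, out_splits, in_splits, group=None):
    # Quantize once, exchange payload + scales, reconstruct on arrival.
    # The fp8 payload travels as a uint8 view so the collective only ever
    # sees a plain byte tensor.
    data, scales = mx_quantize(x)
    dim = x.shape[1]
    n_out = sum(out_splits)
    out_data = torch.empty(n_out, dim, dtype=torch.uint8, device=x.device)
    out_scales = torch.empty(n_out, dim // BLOCK, dtype=torch.uint8, device=x.device)
    dist.all_to_all_single(out_data, data.view(torch.uint8), out_splits, in_splits, group=group)
    dist.all_to_all_single(out_scales, scales, out_splits, in_splits, group=group)
    return mx_dequantize(out_data.view(torch.float8_e4m3fn), out_scales)
```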
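
A matching sketch (plain PyTorch here; the real kernel would be Triton) of the token shuffle with scale co-permutation: whatever permutation groups the fp8 token rows into per-expert contiguous order must also be applied to the per-token e8m0 scale rows, so each row's scales stay attached to its token. The function name and shapes are illustrative only.

```python
import torch


def shuffle_tokens_and_scales(tokens_fp8: torch.Tensor,
                              scales_e8m0: torch.Tensor,
                              expert_ids: torch.Tensor):
    # tokens_fp8:  [num_tokens, dim]        float8_e4m3fn rows
    # scales_e8m0: [num_tokens, dim // 32]  uint8 e8m0 scales, one row per token
    # expert_ids:  [num_tokens]             expert assignment per token
    perm = torch.argsort(expert_ids, stable=True)     # group tokens by expert
    # Index through a uint8 view so we never rely on fp8 indexing kernels.
    shuffled_tokens = tokens_fp8.view(torch.uint8)[perm].view(torch.float8_e4m3fn)
    shuffled_scales = scales_e8m0[perm]                # same permutation as the tokens
    tokens_per_expert = torch.bincount(expert_ids)     # minlength=num_experts in practice
    return shuffled_tokens, shuffled_scales, tokens_per_expert, perm
```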
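
And a hedged interface sketch (not the torchao API) of what "accept pre-quantized inputs" could mean for the grouped gemm autograd func: take either a high-precision activation tensor or an already-quantized (fp8 data, e8m0 scales) pair coming out of the all2all + shuffle, and only quantize when given high precision. The body below is a slow dequantize-and-loop reference purely for illustration; the real path stays in fp8 and calls a single mxfp8 grouped gemm kernel. mx_quantize/mx_dequantize are the hypothetical helpers from the first sketch.

```python
from typing import List, Tuple, Union

import torch


def mx_grouped_mm(
    a: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    expert_weights: List[torch.Tensor],   # one [dim, out_dim] weight per expert
    tokens_per_expert: torch.Tensor,      # [num_experts] token counts, in shuffled order
) -> torch.Tensor:
    # Accept either a high-precision tensor or a pre-quantized (data, scales) pair.
    if isinstance(a, tuple):
        a_data, a_scales = a              # already mxfp8: skip the quantization pass
    else:
        a_data, a_scales = mx_quantize(a)
    # Reference computation only: dequantize and run one matmul per expert group.
    a_hp = mx_dequantize(a_data, a_scales)
    outs, start = [], 0
    for w, n in zip(expert_weights, tokens_per_expert.tolist()):
        outs.append(a_hp[start:start + n] @ w.float())
        start += n
    return torch.cat(outs, dim=0)
```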

danielvegamyhre (Nov 24 '25)

Is there planned support for MXFP8 all-gather as well? I can see that it would be helpful for TP/SP activation all-gathers and FSDP weight all-gathers.

avizon-aws (Nov 24 '25)

> Is there planned support for MXFP8 all-gather as well? I can see that it would be helpful for TP/SP activation all-gathers and FSDP weight all-gathers.

Not at the moment, but we would gladly review PRs for such things!

danielvegamyhre (Nov 24 '25)

Hi @danielvegamyhre, I have created a PR adding support for MXFP8 all-gather; it would be great if you could review it: https://github.com/pytorch/ao/pull/3435

avizon-aws (Dec 04 '25)