Add support for NANOO fp8.
What does this PR do?
This PR adds fp8 dot op support for the NANOO fp8 data formats, an alternative family to the OCP fp8 data formats used by NVIDIA GPUs.
Hardware vendors use several different families of fp8 formats. Two popular ones are:
- OCP fp8, which is used natively on NVIDIA H100
- NANOO fp8, which is used natively on AMD MI300 and Graphcore HW.
The two families work very similarly. This PR enables NANOO fp8 support, which JAX and XLA now also provide, so that the fp8 dot op can be used on AMD MI300 GPUs.
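To illustrate how close the two families are, here is a minimal pure-Python sketch that decodes E4M3 bit patterns under both conventions. It is not code from this PR: the helper names `decode_e4m3` and `max_finite` are hypothetical, and it assumes that the NANOO E4M3 layout matches the `float8_e4m3fnuz` convention (exponent bias 8, a single NaN at `0x80`, no negative zero) while the OCP layout matches `float8_e4m3fn` (exponent bias 7, NaN at `S.1111.111`).

```python
import math


def decode_e4m3(byte: int, nanoo: bool) -> float:
    """Decode one 8-bit E4M3 pattern (1 sign, 4 exponent, 3 mantissa bits).

    Hypothetical helper for illustration only; not part of Flax, JAX, or XLA.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if nanoo:
        # Assumed NANOO-style layout: bias 8, the only NaN is 0x80,
        # and there is no negative zero and no infinity.
        if byte == 0x80:
            return math.nan
        bias = 8
    else:
        # Assumed OCP-style layout: bias 7, NaN when all exponent and
        # mantissa bits are set, no infinity.
        if exp == 0xF and man == 0x7:
            return math.nan
        bias = 7
    if exp == 0:
        # Subnormal numbers have no implicit leading one.
        return sign * (man / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - bias)


def max_finite(nanoo: bool) -> float:
    """Largest finite value representable in the chosen E4M3 variant."""
    values = (decode_e4m3(b, nanoo) for b in range(256))
    return max(v for v in values if not math.isnan(v))


print(max_finite(nanoo=False))  # 448.0
print(max_finite(nanoo=True))   # 240.0
```

Under these assumptions the same bit pattern decodes to half the magnitude in the NANOO variant, which is why the scaling logic around an fp8 dot op needs to know which family it is targeting.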
References:
- OCP fp8 paper: https://arxiv.org/abs/2209.05433
- NANOO fp8 paper: https://arxiv.org/abs/2206.02915
- JAX PR: https://github.com/google/jax/pull/21376
- XLA PR: https://github.com/openxla/xla/pull/9531
@levskaya I noticed that you have reviewed several PRs regarding fp8. Could you take a look at this one?
@levskaya Could you kindly serve as the reviewer for this PR?
@levskaya Thanks for the review. I have updated the PR to address the concerns. Could you take a look at the updates?
Thanks for the fixes! We may need to do some tiny rebasing of simple things as the codebase just migrated to a python minver of 3.10.
Thanks! Do you need me to rebase it to the tip of the tree?
Yes, rebasing to tip as of today should pick up the 3.10 minver updates. Also, I'm seeing this failure in the tests:
FAILED tests/linen/linen_test.py::Fp8Test::test_fp8_meta_dtype0 - TypeError: missing a required argument: 'amax_history'
FAILED tests/linen/linen_test.py::Fp8Test::test_fp8_meta_dtype1 - TypeError: missing a required argument: 'amax_history'
Could you fix that?
Sorry, I missed this test.
I have just rebased and fixed it.
Codecov Report
Attention: Patch coverage is 0% with 17 lines in your changes missing coverage. Please review.
Project coverage is 0.00%. Comparing base (31adb00) to head (a6f52ae). Report is 46 commits behind head on main.
| Files | Patch % | Lines |
|---|---|---|
| flax/linen/fp8_ops.py | 0.00% | 16 Missing :warning: |
| flax/linen/__init__.py | 0.00% | 1 Missing :warning: |
Additional details and impacted files
| Coverage Diff | main | #3993 | +/- |
|---|---|---|---|
| Coverage | 0.00% | 0.00% | |
| Files | 106 | 107 | +1 |
| Lines | 13582 | 13767 | +185 |
| Misses | 13582 | 13767 | +185 |