Added support of NANOO fp8.

Open wenchenvincent opened this issue 1 year ago • 7 comments

What does this PR do?

This PR adds fp8 dot op support for the NANOO fp8 data formats (an alternative genre to the OCP fp8 data formats, which are used by NVIDIA GPUs).

Different HW vendors use different genres of fp8 formats. Two popular genres are:

  • OCP fp8, which is used natively on NVIDIA H100
  • NANOO fp8, which is used natively on AMD MI300 and Graphcore HW.

These two genres of fp8 formats work very similarly. This PR enables support for NANOO fp8, which is now also supported in JAX and XLA, so that the fp8 dot op can be used on AMD MI300 GPUs.
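To illustrate how similar the two genres are, here is a minimal pure-Python sketch that decodes a single byte under the 4-bit-exponent, 3-bit-mantissa variant of each genre. It is not code from this PR: `decode_e4m3` and `max_finite` are hypothetical helper names, and the sketch assumes the commonly documented layouts (OCP E4M3: exponent bias 7, NaN when all exponent and mantissa bits are set; NANOO E4M3: exponent bias 8, no negative zero, NaN only at bit pattern `0x80`).

```python
import math


def decode_e4m3(byte: int, nanoo: bool) -> float:
    """Decode one fp8 byte (1 sign, 4 exponent, 3 mantissa bits).

    Illustrative helper, not part of the PR. nanoo=False uses the OCP
    E4M3 rules, nanoo=True uses the NANOO E4M3 rules.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if nanoo:
        # NANOO: the "negative zero" bit pattern is the only NaN.
        if byte == 0x80:
            return math.nan
        bias = 8
    else:
        # OCP: all exponent and mantissa bits set encodes NaN (either sign).
        if exp == 0xF and man == 0x7:
            return math.nan
        bias = 7
    if exp == 0:
        # Subnormal: no implicit leading one.
        return sign * (man / 8) * 2.0 ** (1 - bias)
    return sign * (1 + man / 8) * 2.0 ** (exp - bias)


def max_finite(nanoo: bool) -> float:
    """Largest finite value representable in the chosen genre."""
    values = (decode_e4m3(b, nanoo) for b in range(256))
    return max(v for v in values if not math.isnan(v))


print(max_finite(nanoo=False))  # 448.0
print(max_finite(nanoo=True))   # 240.0
```

The encodings differ only in the exponent bias and in how NaN and zero are represented, so the same bit pattern decodes to different values and the dynamic ranges differ (448 versus 240 here). In JAX the corresponding dtypes are exposed as `jnp.float8_e4m3fn` (OCP) and `jnp.float8_e4m3fnuz` (NANOO).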

References:

  • OCP fp8 paper: https://arxiv.org/abs/2209.05433
  • NANOO fp8 paper: https://arxiv.org/abs/2206.02915
  • JAX PR: https://github.com/google/jax/pull/21376
  • XLA PR: https://github.com/openxla/xla/pull/9531

wenchenvincent avatar Jun 13 '24 02:06 wenchenvincent

@levskaya I noticed that you have reviewed several PRs regarding fp8. Could you take a look at this one?

wenchenvincent avatar Jun 13 '24 02:06 wenchenvincent

@levskaya Could you kindly serve as the reviewer for this PR?

wenchenvincent avatar Jun 20 '24 01:06 wenchenvincent

@levskaya Thanks for the review. I have updated the PR to address the concerns. Could you take a look at the updates?

wenchenvincent avatar Jun 24 '24 15:06 wenchenvincent

Thanks for the fixes! We may need to do some tiny rebasing of simple things as the codebase just migrated to a python minver of 3.10.

Thanks! Do you need me to rebase it to the tip of the tree?

wenchenvincent avatar Jun 28 '24 04:06 wenchenvincent

Yes, rebasing to tip as of today should pick up the 3.10 minver updates. Also, I'm seeing this failure in the tests:

FAILED tests/linen/linen_test.py::Fp8Test::test_fp8_meta_dtype0 - TypeError: missing a required argument: 'amax_history'
FAILED tests/linen/linen_test.py::Fp8Test::test_fp8_meta_dtype1 - TypeError: missing a required argument: 'amax_history'

could you fix that?

levskaya avatar Jun 28 '24 18:06 levskaya

Sorry I missed this test.

I just rebased and fixed this test.

wenchenvincent avatar Jun 28 '24 20:06 wenchenvincent

Codecov Report

Attention: Patch coverage is 0% with 17 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (31adb00) to head (a6f52ae). Report is 46 commits behind head on main.

Files                     Patch %   Lines
flax/linen/fp8_ops.py     0.00%     16 Missing
flax/linen/__init__.py    0.00%     1 Missing
Additional details and impacted files
@@          Coverage Diff           @@
##            main   #3993    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files        106     107     +1     
  Lines      13582   13767   +185     
======================================
- Misses     13582   13767   +185     


codecov-commenter avatar Jun 28 '24 22:06 codecov-commenter