Add JAX cuDNN SDPA API support in Maxtext

Open Cjkkkk opened this issue 9 months ago • 3 comments

Add JAX cuDNN SDPA API support in MaxText. This supports training, prefill, and decoding.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [x] I have performed a self-review of my code.
  • [x] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [x] I have run end-to-end tests and provided workload links above if applicable.
  • [x] I have made or will make corresponding changes to the doc if needed.

Cjkkkk, Apr 23 '25 21:04

Can we add some unit test here protecting this? Just running this code without asserting anything is fine

Thank you for adding this feature!

+1 on adding tests

For training tests on GPUs, you can take this as an example: https://github.com/AI-Hypercomputer/maxtext/blob/9f62bc4a4858dbecc6a9d9b6aea6efeb1e28b0d5/MaxText/tests/train_tests.py#L194.

yangyuwei, Apr 23 '25 23:04

Can we add some unit test here protecting this? Just running this code without asserting anything is fine

+1

bvandermoon, Apr 24 '25 08:04

Can we add some unit test here protecting this? Just running this code without asserting anything is fine

Thank you for adding this feature!

+1 on adding tests

For training tests on GPUs, you can take this as an example: https://github.com/AI-Hypercomputer/maxtext/blob/9f62bc4a4858dbecc6a9d9b6aea6efeb1e28b0d5/MaxText/tests/train_tests.py#L194.

I added one example. When I ran python -m unittest MaxText.tests.train_tests.TrainTests.test_gpu_cudnn_flash_jax, I got:

FAILED (errors=1)
Exception ignored in: <function GoodputMonitor.__del__ at 0x7f8ee7fd5e40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/ml_goodput_measurement/src/monitoring.py", line 152, in __del__
    if self._uploader_thread_running:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'GoodputMonitor' object has no attribute '_uploader_thread_running'

Any idea what is missing here?

Cjkkkk, Apr 24 '25 21:04
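
A note on reading the output above: "Exception ignored in ... __del__" messages are not counted as unittest errors, so the FAILED (errors=1) line most likely comes from a separate traceback earlier in the log. One common way to get this kind of AttributeError is that __init__ raised before assigning the attribute, and Python still ran __del__ on the partially built object. Whether that is what happened to GoodputMonitor here is an assumption. The sketch below reproduces the pattern with hypothetical stand-in classes (FragileMonitor, DefensiveMonitor, and finalizer_errors are illustrative names, not part of ml_goodput_measurement):

```python
import gc
import sys


class FragileMonitor:
    """Hypothetical stand-in, not the real GoodputMonitor."""

    def __init__(self, fail_early: bool):
        if fail_early:
            # Simulates a setup error raised before any attribute is assigned.
            raise RuntimeError("setup failed")
        self._uploader_thread_running = False

    def __del__(self):
        # Raises AttributeError if __init__ never reached the assignment.
        if self._uploader_thread_running:
            pass


class DefensiveMonitor(FragileMonitor):
    def __del__(self):
        # getattr with a default tolerates a partially constructed object.
        if getattr(self, "_uploader_thread_running", False):
            pass


def finalizer_errors(cls):
    """Build cls with a failing __init__ and collect errors raised in __del__."""
    seen = []
    previous_hook = sys.unraisablehook
    # Errors inside __del__ are reported through sys.unraisablehook.
    sys.unraisablehook = lambda info: seen.append(info.exc_type.__name__)
    try:
        try:
            cls(fail_early=True)
        except RuntimeError:
            pass  # The primary failure; the __del__ error is a side effect.
        gc.collect()
    finally:
        sys.unraisablehook = previous_hook
    return seen
```

If this pattern applies, the useful information is in the first traceback of the test run rather than in the __del__ message, so scrolling up to the primary error is the next step.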