Add JAX cuDNN SDPA API support in MaxText
Add JAX cuDNN SDPA API support in MaxText. This supports training, prefill, and decoding.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.
Can we add a unit test here to protect this? Just running this code without asserting anything is fine.
Thank you for adding this feature!
+1 on adding tests
For training tests on GPUs, you can take this as an example: https://github.com/AI-Hypercomputer/maxtext/blob/9f62bc4a4858dbecc6a9d9b6aea6efeb1e28b0d5/MaxText/tests/train_tests.py#L194.
+1
I added one test based on that example. When I ran `python -m unittest MaxText.tests.train_tests.TrainTests.test_gpu_cudnn_flash_jax`, I got:
FAILED (errors=1)
Exception ignored in: <function GoodputMonitor.__del__ at 0x7f8ee7fd5e40>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/ml_goodput_measurement/src/monitoring.py", line 152, in __del__
if self._uploader_thread_running:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'GoodputMonitor' object has no attribute '_uploader_thread_running'
Any idea what is missing here?