maxtext
Fix the dtype passed when calling the MoE megablox gmm kernel.
Description
When calling the MoE megablox gmm kernel, `preferred_element_type` is hardcoded to `bfloat16`. This change replaces the hardcoded value with the dtype controlled by the config.
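A minimal sketch of the pattern, for illustration only. It does not call the real megablox kernel. Instead it uses a small pure-JAX reference grouped matmul built on `jax.lax.dot`, which accepts `preferred_element_type`. `MoEConfig`, `reference_gmm`, and `moe_matmul` are hypothetical names, and `MoEConfig.dtype` is an assumed stand-in for the config-controlled dtype.

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp
from jax import lax


@dataclass
class MoEConfig:
    # Hypothetical config field standing in for the config-controlled dtype.
    dtype: Any = jnp.bfloat16


def reference_gmm(lhs, rhs, group_sizes, preferred_element_type):
    """Pure-JAX reference grouped matmul (NOT the megablox kernel).

    lhs: [m, k], rows sorted by group.
    rhs: [num_groups, k, n], one weight matrix per group.
    group_sizes: static Python ints that sum to m.
    """
    outputs = []
    start = 0
    for g, size in enumerate(group_sizes):
        rows = lhs[start:start + size]
        # Each group's rows are multiplied by that group's weights; the
        # result dtype follows preferred_element_type.
        outputs.append(
            lax.dot(rows, rhs[g], preferred_element_type=preferred_element_type)
        )
        start += size
    return jnp.concatenate(outputs, axis=0)


def moe_matmul(config, lhs, rhs, group_sizes):
    # Before: preferred_element_type=jnp.bfloat16 (hardcoded).
    # After: the output dtype comes from the config.
    return reference_gmm(lhs, rhs, group_sizes, preferred_element_type=config.dtype)
```

With `MoEConfig(dtype=jnp.float32)`, `moe_matmul` returns a `float32` result even when the inputs are `bfloat16`; with the default config it returns `bfloat16`, matching the previous hardcoded behavior.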
Tests
No functional change; relying on presubmit tests.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.