maxtext
Add MoE matmul implementation
Description
- Add a matmul implementation to replace the for loop in MoE models (as an alternative milestone to Megablox, due to the blocker summarized here). Eventually, we want to adopt the Megablox implementation.
- Add a flag `moe_matmul` for now instead of fully replacing the for loop, as the current perf is not better than the for-loop implementation (~40 TFLOP/s/device for matmul vs. ~100 TFLOP/s/device for the for loop) without expert sharding (will be added in a separate PR).
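
For readers unfamiliar with the two approaches, here is a minimal JAX sketch of the general idea: a per-expert Python for loop versus a single batched matmul (einsum) over a stacked expert dimension. This is not the MaxText code from this PR. The function names (`route_top_k`, `moe_for_loop`, `moe_dense_matmul`), the shapes, the SiLU activation, and the top-k softmax routing are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def route_top_k(x, w_gate, k):
    """Hypothetical router: dense [tokens, experts] weights, zero for unselected experts."""
    logits = x @ w_gate                               # [tokens, experts]
    top_vals, top_idx = jax.lax.top_k(logits, k)      # [tokens, k] each
    top_w = jax.nn.softmax(top_vals, axis=-1)         # normalize over the k selected experts
    one_hot = jax.nn.one_hot(top_idx, w_gate.shape[-1], dtype=x.dtype)  # [tokens, k, experts]
    return jnp.einsum("tk,tke->te", top_w, one_hot)   # scatter weights to expert slots


def moe_for_loop(x, w_gate, w_in, w_out, k):
    """Loop over experts; each iteration is a separate pair of matmuls."""
    weights = route_top_k(x, w_gate, k)
    out = jnp.zeros_like(x)
    for e in range(w_in.shape[0]):
        h = jax.nn.silu(x @ w_in[e])                  # [tokens, d_ff]
        out = out + weights[:, e:e + 1] * (h @ w_out[e])
    return out


def moe_dense_matmul(x, w_gate, w_in, w_out, k):
    """Same computation with the expert axis folded into batched einsums."""
    weights = route_top_k(x, w_gate, k)
    h = jax.nn.silu(jnp.einsum("td,edf->tef", x, w_in))   # [tokens, experts, d_ff]
    y = jnp.einsum("tef,efd->ted", h, w_out)              # [tokens, experts, d_model]
    return jnp.einsum("te,ted->td", weights, y)           # weighted sum over experts


# Toy sizes (assumed, far smaller than any real model).
tokens, d_model, d_ff, n_experts, top_k = 8, 16, 32, 4, 2
kx, kg, ki, ko = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (tokens, d_model))
w_gate = jax.random.normal(kg, (d_model, n_experts))
w_in = 0.1 * jax.random.normal(ki, (n_experts, d_model, d_ff))
w_out = 0.1 * jax.random.normal(ko, (n_experts, d_ff, d_model))

out_loop = moe_for_loop(x, w_gate, w_in, w_out, top_k)
out_mm = moe_dense_matmul(x, w_gate, w_in, w_out, top_k)
print("max abs diff:", jnp.max(jnp.abs(out_loop - out_mm)))
```

In this sketch both variants compute every expert for every token and mask with the routing weights, so they are mathematically identical and differ only in how the work is expressed to the compiler. Floating-point accumulation order can still differ between the two, which is one reason to compare outputs with a tolerance rather than exact equality.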
Test
At a high level:
- unit test `moe_test.py` to verify output matches
- end-to-end test for small model size: result matches
- end-to-end test for original model size: noticed a numerical issue (the generated text diverges from the for-loop output)
Test links:
- end-to-end for loop (current implementation): test
Input `[INST] I love to [/INST]` -> `That's great to hear! I love to learn new things`
- end-to-end matmul (my change): test
Input `[INST] I love to [/INST]` -> `That's great to hear! Could you please tell me what`