
Add MoE matmul implementation

Open RissyRan opened this issue 1 year ago • 0 comments

Description

  • Add a matmul implementation to replace the for loop in MoE models (as an alternative milestone to Megablox, due to the blocker summarized here). Eventually, we want to adopt the Megablox implementation.
  • Add a flag, moe_matmul, for now instead of replacing the for loop outright, as the current perf is not better than the for-loop implementation (~40 TFLOP/s/device vs. ~100 TFLOP/s/device) without expert sharding (which will be added in a separate PR).
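
To illustrate the difference between the two code paths, here is a minimal JAX sketch. It is not the MaxText implementation: the function names (`moe_for_loop`, `moe_matmul`), the single-projection expert MLP, the weight shapes, and the top-2 softmax routing are all assumptions chosen to keep the example small. The for-loop version visits experts one at a time; the matmul version builds a dense token-to-expert combine matrix and applies all experts with batched einsums.

```python
import jax
import jax.numpy as jnp


def route(x, gate_w, top_k):
    """Hypothetical router: top-k experts per token, softmax over the chosen logits."""
    logits = x @ gate_w                               # [tokens, experts]
    top_vals, top_idx = jax.lax.top_k(logits, top_k)  # both [tokens, top_k]
    return jax.nn.softmax(top_vals, axis=-1), top_idx


def moe_for_loop(x, gate_w, w_in, w_out, top_k=2):
    """Loop over experts and accumulate each expert's weighted output."""
    weights, top_idx = route(x, gate_w, top_k)
    out = jnp.zeros_like(x)
    for e in range(gate_w.shape[1]):
        # Per-token weight for expert e (0 where e was not selected).
        w_e = jnp.sum(jnp.where(top_idx == e, weights, 0.0), axis=-1)
        h = jax.nn.silu(x @ w_in[e])                  # [tokens, hidden]
        out = out + w_e[:, None] * (h @ w_out[e])     # [tokens, embed]
    return out


def moe_matmul(x, gate_w, w_in, w_out, top_k=2):
    """Same computation with batched einsums instead of a Python loop."""
    weights, top_idx = route(x, gate_w, top_k)
    num_experts = gate_w.shape[1]
    one_hot = jax.nn.one_hot(top_idx, num_experts, dtype=x.dtype)  # [tokens, top_k, experts]
    combine = jnp.einsum("tk,tke->te", weights, one_hot)           # dense routing weights
    h = jax.nn.silu(jnp.einsum("td,edh->teh", x, w_in))            # all experts at once
    y = jnp.einsum("teh,ehd->ted", h, w_out)                       # [tokens, experts, embed]
    return jnp.einsum("te,ted->td", combine, y)


# Toy shapes (assumed): 8 tokens, embed 16, hidden 32, 4 experts.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (8, 16))
gate_w = jax.random.normal(k2, (16, 4))
w_in = 0.1 * jax.random.normal(k3, (4, 16, 32))
w_out = 0.1 * jax.random.normal(k4, (4, 32, 16))

y_loop = moe_for_loop(x, gate_w, w_in, w_out)
y_mm = moe_matmul(x, gate_w, w_in, w_out)
print("max abs diff:", float(jnp.max(jnp.abs(y_loop - y_mm))))
```

In this dense form every expert processes every token and the combine matrix zeroes out the unselected ones, which trades extra FLOPs for larger, more regular matmuls. The two paths sum the expert contributions in a different order, so their outputs are expected to agree only up to floating-point tolerance.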

Test

At a high level:

  • unit test moe_test.py to verify the output matches
  • end-to-end test for small model size: result matches
  • end-to-end test for original model size: noticed a numeric issue

Test links:

  • end-to-end for loop (current implementation): test
    Input `[INST] I love to [/INST]` -> `That's great to hear! I love to learn new things`
  • end-to-end matmul (my change): test
    Input `[INST] I love to [/INST]` -> `That's great to hear! Could you please tell me what`

RissyRan, May 21 '24 04:05