
Does torchao support FP8 Grouped GEMM?

Open zigzagcai opened this issue 1 year ago • 7 comments

Grouped GEMM kernels (https://github.com/fanshiqing/grouped_gemm) are used in many MoE models.

I'm just wondering whether torchao supports FP8 kernels for Grouped GEMM, such as the three commonly used ops below (a rough sketch of what they compute follows the list):

grouped_gemm.backend.gmm
grouped_gemm.ops.unpermute
grouped_gemm.ops.permute
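
For readers who haven't used that library: conceptually these ops implement the permute → grouped GEMM → unpermute pipeline sketched below. This is a plain-PyTorch illustration only, not the library's actual API or implementation:

```python
import torch

# Toy illustration of what permute -> grouped GEMM -> unpermute compute,
# written in plain PyTorch for clarity (not the library's implementation).
num_experts, hidden, ffn, num_tokens = 4, 64, 128, 8
tokens = torch.randn(num_tokens, hidden)
expert_ids = torch.randint(0, num_experts, (num_tokens,))   # top-1 routing
weights = torch.randn(num_experts, hidden, ffn)             # one weight matrix per expert

# "permute": sort tokens so each expert's tokens are contiguous along dim0.
order = torch.argsort(expert_ids)
permuted = tokens[order]
counts = torch.bincount(expert_ids, minlength=num_experts)

# "gmm": one GEMM per expert over its contiguous slice of tokens.
out = torch.empty(num_tokens, ffn)
start = 0
for e in range(num_experts):
    end = start + counts[e].item()
    out[start:end] = permuted[start:end] @ weights[e]
    start = end

# "unpermute": scatter the results back to the original token order.
unpermuted = torch.empty_like(out)
unpermuted[order] = out
```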

zigzagcai avatar Mar 20 '25 03:03 zigzagcai

hi @zigzagcai , we recently landed a grouped gemm API into core which includes fp8: https://github.com/pytorch/pytorch/pull/148531 . We plan to provide wrappers in torchao, although we do not have them just yet. cc @drisspg

vkuzo avatar Mar 20 '25 12:03 vkuzo

hi @zigzagcai , we recently landed a grouped gemm API into core which includes fp8: https://github.com/pytorch/pytorch/pull/148531 . We plan to provide wrappers in torchao, although we do not have them just yet. cc @drisspg

Thank you @vkuzo ! I'm just wondering how I can use these newly added aten grouped gemm ops?

zigzagcai avatar Mar 20 '25 14:03 zigzagcai

cc @HDCharles who has been looking into MoE quantization and grouped gemm recently

supriyar avatar Mar 20 '25 18:03 supriyar

Hey,

I'm working on enabling our existing quantization kernels to compose with grouped gemm; it's still in progress at the moment. As for the core kernel, you can look at: https://github.com/pytorch/pytorch/pull/148531/files#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23R1178

...for an example
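
For a rough idea of how the core op is called, here is a minimal sketch based on my reading of that test. The shapes, scale layouts, and hardware requirements are assumptions that may differ by PyTorch version, and the unit scales are placeholders for real dynamic rowwise scales:

```python
import torch

device = "cuda"
num_experts, K, N, total_tokens = 2, 64, 128, 16

# A holds all tokens stacked along dim0; offs marks the end index of each
# expert's group, e.g. expert 0 owns rows [0, 8), expert 1 owns rows [8, 16).
A = torch.randn(total_tokens, K, device=device, dtype=torch.bfloat16)
B = torch.randn(num_experts, N, K, device=device, dtype=torch.bfloat16)
offs = torch.tensor([8, 16], device=device, dtype=torch.int32)

# Placeholder quantization: real code would compute rowwise scales dynamically;
# unit scales just keep the sketch short.
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn)
scale_A = torch.ones(total_tokens, device=device, dtype=torch.float32)
scale_B = torch.ones(num_experts, N, device=device, dtype=torch.float32)

out = torch._scaled_grouped_mm(
    A_fp8,
    B_fp8.transpose(-2, -1),   # op expects mat_b column-major in its last two dims
    scale_A,
    scale_B,
    offs=offs,
    out_dtype=torch.bfloat16,
)  # out: (total_tokens, N) in bf16
```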

HDCharles avatar Mar 20 '25 22:03 HDCharles

@HDCharles @vkuzo

Interested in this as well, and potentially in helping tune the kernel.

The grouped gemm PR mentions a link to a doc describing the design of the grouped GEMM. How can I view that doc (access seems to be gated)?

jeromeku avatar Mar 22 '25 12:03 jeromeku

Any follow-ups / responses on this issue? @danielvegamyhre @HDCharles

jerryzh168 avatar May 01 '25 18:05 jerryzh168

You can take a look at the prototype differentiable scaled grouped mm, which performs dynamic fp8 rowwise quantization on the inputs and then calls torch._scaled_grouped_mm to compute the output. Similar steps are performed for the backward pass. It is still in the early stages and is primarily designed for MoE training, where the inputs are:

  • "A tensor": a 2D input tensor, with offsets defining the end index of each group along dim0
  • "B tensor": a 3D weights tensor of shape (num_experts, dim1, dim2)

I've added some custom Triton kernels to compute per-group scaling factors without a host-device sync, which improved performance a lot (2x-6x faster for most shapes). Take a look and let me know what you think!
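
For reference, the per-row dynamic quantization step is conceptually just the following. This is a simplified sketch; the prototype fuses this logic into the Triton kernels mentioned above and also handles the per-expert scaling of the B tensor:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def fp8_rowwise_quantize(x: torch.Tensor):
    # One dequantization scale per row: scale = amax(row) / fp8_max, so that
    # x ~= x_fp8.float() * scale[:, None] after dequantization.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
    scale = amax / FP8_E4M3_MAX
    x_fp8 = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.squeeze(-1)
```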

danielvegamyhre avatar May 01 '25 18:05 danielvegamyhre

@danielvegamyhre Hi, I'm wondering why the scaled grouped mm has been removed?

foreverlms avatar Jul 09 '25 07:07 foreverlms

@danielvegamyhre Hi, I'm wondering why the scaled grouped mm has been removed?

Hi, the prototype was just renamed to moe_training to capture all the other stuff it includes.

danielvegamyhre avatar Jul 09 '25 15:07 danielvegamyhre