Does torchao support FP8 Grouped GEMM?
Grouped GEMM kernels (https://github.com/fanshiqing/grouped_gemm) are used in many MoE models.
I'm just wondering whether torchao supports FP8 kernels for Grouped GEMM, such as the three commonly used ops:
grouped_gemm.backend.gmm
grouped_gemm.ops.unpermute
grouped_gemm.ops.permute
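For context, the permute/unpermute ops reorder token rows so that rows routed to the same expert are contiguous, which is the layout grouped GEMM expects. A simplified numpy sketch of the idea (illustrative only, not the library's actual implementation):

```python
import numpy as np

def permute(tokens, expert_idx):
    """Sort token rows so rows routed to the same expert are contiguous.

    Returns the reordered tokens and the sort order needed to invert it.
    Uses a stable sort so ties keep their original relative order.
    """
    order = np.argsort(expert_idx, kind="stable")
    return tokens[order], order

def unpermute(tokens_sorted, order):
    """Invert `permute`: scatter rows back to their original positions."""
    out = np.empty_like(tokens_sorted)
    out[order] = tokens_sorted
    return out

# 6 tokens routed to 3 experts
tokens = np.arange(12, dtype=np.float32).reshape(6, 2)
expert_idx = np.array([1, 0, 1, 0, 2, 0])
sorted_tokens, order = permute(tokens, expert_idx)
restored = unpermute(sorted_tokens, order)
```

The real kernels fuse this reordering with the routing metadata on the GPU; the sketch only shows the index math.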
hi @zigzagcai , we recently landed a grouped gemm API into core which includes fp8: https://github.com/pytorch/pytorch/pull/148531 . We plan to provide wrappers in torchao, although we do not have them just yet. cc @drisspg
Thank you @vkuzo ! I'm just wondering how I can use these newly added aten grouped gemm ops?
cc @HDCharles who has been looking into MoE quantization and grouped gemm recently
Hey,
I'm working on enabling our existing quantization kernels to compose with grouped gemm; it's still in progress at the moment. For the core kernel, you can look at: https://github.com/pytorch/pytorch/pull/148531/files#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23R1178
...for an example
@HDCharles @vkuzo
Interested in this as well and potentially helping tune the kernel.
There is a link mentioned in the grouped gemm PR describing the design of the grouped GEMM. How can I view the doc (access seems to be gated)?
Any follow-ups or responses on this issue? @danielvegamyhre @HDCharles
You can take a look at the prototype differentiable scaled grouped mm which performs dynamic fp8 rowwise quantization on the inputs then performs a torch._scaled_grouped_mm for the output. Similar steps are performed for the backward pass. It is still in the early stages and is primarily designed for MoE training where the inputs are:
- "A tensor" (2D input tensor w/ offsets defining the end idx of each group along dim0)
- "B tensor" (a 3D weights tensor of shape (num_experts, dim1, dim2))
I've added some custom triton kernels to compute per-group scaling factors without a host-device sync, which improved perf a lot (2x - 6x faster for most shapes). Take a look and let me know what you think!
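To illustrate the quantization step described above, here is a simplified numpy sketch of dynamic fp8 rowwise quantization composed with a grouped matmul over offset-defined groups. Names and details are illustrative assumptions; the real path uses triton kernels, true float8_e4m3fn casting, and torch._scaled_grouped_mm on the GPU, while this sketch only clamps to the fp8 range to show the scaling math:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def fp8_rowwise_quantize(x):
    """Dynamic rowwise quantization: one scale per row of `x`.

    Simulates fp8 by clamping to the e4m3 range instead of actually
    rounding to 8-bit values.
    """
    amax = np.abs(x).max(axis=1, keepdims=True)
    scale = FP8_E4M3_MAX / np.maximum(amax, 1e-12)
    q = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale

def scaled_grouped_mm_sketch(a, b, offs):
    """Grouped matmul where `offs` gives the end index of each group
    along dim0 of `a`, and `b` has shape (num_experts, dim1, dim2).

    Quantizes `a` rowwise, then dequantizes and multiplies each group
    by its expert's weights (a real kernel would keep the fp8 data and
    fold the scales into the matmul epilogue instead).
    """
    a_q, a_scale = fp8_rowwise_quantize(a)
    out, start = [], 0
    for g, end in enumerate(offs):
        out.append((a_q[start:end] / a_scale[start:end]) @ b[g])
        start = end
    return np.concatenate(out, axis=0)

# token rows for 2 experts: rows 0..2 -> expert 0, rows 3..4 -> expert 1
a = np.random.randn(5, 8).astype(np.float32)
b = np.random.randn(2, 8, 4).astype(np.float32)
out = scaled_grouped_mm_sketch(a, b, offs=[3, 5])
```

The per-group scaling factors here are computed on the fly from the data, which is what makes the quantization "dynamic"; doing that without a host-device sync is where the custom triton kernels come in.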
@danielvegamyhre Hi, wondering why the scaled grouped mm has been removed?
Hi, the prototype was just renamed to moe_training to capture all the other stuff it includes.