Add split-K heuristic for decode attention
Summary: This diff adds an automatic split-K size heuristic for the Blackwell FMHA decode kernel to optimize GPU utilization.
Added get_splitk_heuristic(), which automatically computes the optimal split-K size.
The heuristic keeps split sizes at multiples of TileN (256) and disables split-K when the sequence would fit in a single split.
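For illustration only, the two constraints above can be sketched in Python. This is not the code added in this diff: the function name splitk_heuristic_sketch, its parameters, the "one CTA per (batch, KV head)" occupancy model, and the convention of returning 0 to mean "split-K disabled" are all assumptions made for the example. Only the TileN value of 256 and the two rules (tile-aligned split sizes, no split-K for a single split) come from the summary.

```python
import math

TILE_N = 256  # KV tile size stated in the summary


def splitk_heuristic_sketch(batch, kv_heads, seq_len_kv, num_sms, tile_n=TILE_N):
    """Hypothetical split-size picker; returns 0 when split-K should be disabled."""
    # Assumed occupancy model: without split-K, one CTA per (batch, KV head).
    base_ctas = max(1, batch * kv_heads)
    # Number of splits that would roughly fill the GPU with CTAs.
    target_splits = max(1, num_sms // base_ctas)
    # Number of TileN-sized tiles along the KV sequence (at least one).
    num_tiles = max(1, math.ceil(seq_len_kv / tile_n))
    # A split cannot be smaller than one tile.
    target_splits = min(target_splits, num_tiles)
    # Round the split up to whole tiles so it is a multiple of TileN.
    tiles_per_split = math.ceil(num_tiles / target_splits)
    split_size = tiles_per_split * tile_n
    # If the whole sequence fits in one split, split-K adds only overhead.
    num_splits = math.ceil(seq_len_kv / split_size)
    if num_splits <= 1:
        return 0
    return split_size
```

Under these assumptions, a small batch with a long sequence gets split into tile-aligned chunks, while a batch large enough to fill the GPU on its own leaves split-K off.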
Performance: benchmarks show a consistent 15-34% speedup over Triton split-K across all tested configurations:
- Average speedup: 1.24x
- Min speedup: 1.15x (Batch=16, SeqLen=32768)
- Max speedup: 1.34x (Batch=128, SeqLen=8192)
Reviewed By: jianyuh
Differential Revision: D89016012
@Aya-ZIbra has exported this pull request. If you are a Meta employee, you can view the originating Diff in D89016012.
This pull request has been merged in pytorch/FBGEMM@3086dd201373085b01da748f663285f98b1572c8.