Bugfix: support qk_head_dim != v_head_dim in FMHA
Description
The original implementation does not support qk_head_dim != v_head_dim, which is needed for Multi-head Latent Attention (MLA); this PR adds that support and also fixes some test-code logic. Problem sizes in samples/AttentionFMHA.py are updated such that qk_head_dim != v_head_dim and q_num_head != kv_num_head, so the tests exercise a generic GQA case. The parameters and the way PyTorch's scaled_dot_product_attention is called as a reference are also updated, so that PyTorch can always find a working backend for these shapes.
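For context, here is a minimal sketch of such a reference call, not the PR's actual test code: the shapes and dtypes are illustrative, and the enable_gqa flag assumes PyTorch >= 2.5.

```python
# Sketch: reference attention with qk_head_dim != v_head_dim (MLA-style)
# and q_num_head != kv_num_head (generic GQA). Illustrative sizes only.
import torch
import torch.nn.functional as F

batch, seq_len = 2, 128
q_num_head, kv_num_head = 8, 2        # q heads must be a multiple of kv heads
qk_head_dim, v_head_dim = 192, 128    # differing QK and V head dims

q = torch.randn(batch, q_num_head, seq_len, qk_head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, kv_num_head, seq_len, qk_head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, kv_num_head, seq_len, v_head_dim, device="cuda", dtype=torch.float16)

# enable_gqa lets SDPA broadcast kv heads across q heads; with
# qk_head_dim != v_head_dim, dispatch may fall back to the math backend.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(ref.shape)  # (batch, q_num_head, seq_len, v_head_dim)
```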
All tests have passed locally on a B200.
Checklist
- [x] I am familiar with the Contributing Guidelines.
- [x] New or existing tests cover these changes.
- [x] The documentation is up to date with these changes.