rfcs: multi-head attention
Description
A link to the rendered document: link
Fixes # (github issue)
Checklist
General
- [ ] Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
- [ ] Have you formatted the code using clang-format?
Performance improvements
- [ ] Have you submitted performance data that demonstrates performance improvements?
New features
- [ ] Have you published an RFC for the new feature?
- [ ] Was the RFC approved?
- [ ] Have you added relevant tests?
Bug fixes
- [ ] Have you included information on how to reproduce the issue (either in a github issue or in this PR)?
- [ ] Have you added relevant regression tests?
RFC PR
- [ ] Does RFC document follow the template?
- [ ] Have you added a link to the rendered document?
Hi @igorsafo, since GC MHA patterns have a lot of variation, what would such an API look like? Thanks!
Since no primitive API is planned, there is no need for an additional API. Frameworks will build graphs using the oneDNN Graph API; oneDNN will match the SDPA patterns and create either a GC-based or a oneDNN primitive-based implementation as a partition.
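To make the "framework passes a graph, library returns partitions" flow concrete, here is a minimal toy sketch. It is hypothetical and is not the oneDNN Graph API: it assumes a graph is a linear chain of op-name strings (real oneDNN Graph works on a DAG of ops connected by logical tensors), and SUPPORTED_PATTERNS and get_partitions are illustrative names. The point is that the list of fusable patterns lives inside the library, so the caller never has to know it.

```python
from typing import List, Tuple

# Hypothetical: fusion patterns known only to the library. A real library keeps
# this internally and can extend it between releases without caller changes.
# Longer patterns are listed first so they win over their shorter variants.
SUPPORTED_PATTERNS: List[Tuple[str, ...]] = [
    ("MatMul", "Divide", "Add", "SoftMax", "MatMul"),  # SDPA with attention mask
    ("MatMul", "Divide", "SoftMax", "MatMul"),         # SDPA without mask
]


def get_partitions(ops: List[str]) -> List[Tuple[str, ...]]:
    """Greedily split a chain of ops into fused and single-op partitions."""
    partitions: List[Tuple[str, ...]] = []
    i = 0
    while i < len(ops):
        for pattern in SUPPORTED_PATTERNS:
            if tuple(ops[i:i + len(pattern)]) == pattern:
                # A known pattern starts here: emit it as one fused partition.
                partitions.append(pattern)
                i += len(pattern)
                break
        else:
            # No pattern starts at this op: it becomes its own partition.
            partitions.append((ops[i],))
            i += 1
    return partitions


graph = ["LayerNorm", "MatMul", "Divide", "Add", "SoftMax", "MatMul", "Add"]
print(get_partitions(graph))
```

Here the caller hands over the whole chain and gets back a LayerNorm partition, one fused masked-SDPA partition, and a trailing Add partition, without encoding any pattern itself.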
Thanks for the info, @igorsafo!
The PyTorch graph path has now moved to matching patterns on the framework side. This approach requires oneDNN Graph patterns to be hard-coded in PyTorch, and a PyTorch op is then called to compile and execute the partitions. It looks like this RFC wouldn't change our approach.
> PyTorch graph path has now moved to matching patterns on the framework side.
This approach might result in a lot of issues in the future, so I would highly recommend discussing it further with the PyTorch maintainers before we put a lot of effort into Graph API integration. Since Graph patterns contain many more ops (vs. primitive + post-ops), any change within a pattern will result in a miss. Also, if we add an optimization for a new pattern, it will not be caught by the framework pattern matcher until the matcher is explicitly updated in the framework codebase. The main benefit of the Graph API is that an application or framework can pass a graph to it and oneDNN will return optimized partitions, which is not possible with oneDNN primitives: with the primitive API, the user has to know which primitives and post-ops are supported. A pattern matcher on the framework side removes this benefit.
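The "any change within a pattern will result in a miss" concern can be illustrated with a small hypothetical sketch. It assumes a framework-side matcher that hard-codes one SDPA op sequence and offloads a subgraph only on an exact match; FRAMEWORK_SDPA_PATTERN and framework_offloads are illustrative names, not PyTorch or oneDNN code, and a subgraph is again simplified to a list of op names.

```python
from typing import List

# Hypothetical: the single SDPA variant hard-coded in the framework.
FRAMEWORK_SDPA_PATTERN: List[str] = ["MatMul", "Divide", "SoftMax", "MatMul"]


def framework_offloads(ops: List[str]) -> bool:
    """Offload only when the subgraph is exactly the hard-coded pattern."""
    return ops == FRAMEWORK_SDPA_PATTERN


plain = ["MatMul", "Divide", "SoftMax", "MatMul"]
# Same computation plus an attention mask: one extra Add breaks the match,
# even if the library itself could fuse this variant.
masked = ["MatMul", "Divide", "Add", "SoftMax", "MatMul"]

print(framework_offloads(plain), framework_offloads(masked))
```

The masked variant is never offloaded until someone adds it to the framework's pattern list by hand, which is exactly the maintenance burden described above.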
+@TaoLv @vpirogov
Hi @igorsafo, I agree that this approach requires more work on the framework side, since the oneDNN Graph partitions need to be hard-coded in the framework.
Meta would like to follow the approach of letting the framework decide which patterns to offload to another library. This runs contrary to oneDNN Graph's ease-of-use principle of passing the full graph to oneDNN Graph, which then determines which patterns it can support.