[NVIDIA] Add new SDPA API to jax.nn
Attention plays a crucial role in modern transformer-based models. While various variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to enhance the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-based large language models.
This PR proposes introducing a new API in the jax.nn module to handle attention. It will first try the cuDNN flash attention execution path when the config is compatible; otherwise it falls back to a JAX implementation.
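For readers unfamiliar with the fallback path, the reference computation is standard scaled dot-product attention. Below is a minimal NumPy sketch of that reference semantics (the function name, signature, and mask convention are illustrative only, not the proposed jax.nn API):

```python
import numpy as np

def dot_product_attention(q, k, v, mask=None):
    # Reference scaled dot-product attention.
    # q, k, v: (..., seq_len, head_dim); mask (optional): True = attend.
    d = q.shape[-1]
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    if mask is not None:
        # Masked-out positions get a large negative logit before softmax.
        logits = np.where(mask, logits, -1e9)
    # Numerically stable softmax over the key axis.
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

The cuDNN path computes the same quantity, fused and tiled for the accelerator; only performance (and potentially low-precision numerics) differ.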
cc. @nluehr @Cjkkkk @cliffwoolley
@hawkinsp Can you help find reviewers?
Pushed a new commit to remove the use of is_training for the cudnn flash attention. This is a followup of this merged PR.
@sharadmv Any updates?
The API should have an `implementation` option, taking values like `"xla"`, `"cudnn"`, and `None` (the default, which selects the best algorithm). This list will grow with alternative kernel implementations (Pallas, etc.). It is important to be able to select the implementation type:

- `"cudnn"` will fail immediately if there is some unsupported shape, which prevents silent reversions to slow code paths.
- Being able to generate serialized models to do inference with on a different device type (e.g. train on GPU and test on TPU).

Regarding the names: does cuDNN expose both FlashAttention and non-FlashAttention? Perhaps this should be `"cudnn_flash"`? Note that XLA also has different implementations: we could support the low-memory chunked implementation given here (https://arxiv.org/abs/2112.05682) that inspired FlashAttention, which is closer numerically to FlashAttention than standard attention and has the same memory complexity (maybe `"xla_chunked"`? `"xla_low_memory"`?).

Are there any configuration options a user might want to pass to the cuDNN implementation? If so, it could be a string or a cuDNN config dataclass. E.g. in the low-memory XLA case, the chunk size is something a user might want to configure.
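To make the low-memory chunked idea concrete, here is a NumPy sketch in the spirit of arXiv:2112.05682: the key/value sequence is processed in chunks with a running max and normalizer, so the full (q_len, k_len) logits matrix is never materialized. The function name and the `chunk_size` parameter are illustrative, not part of any proposed API:

```python
import numpy as np

def chunked_attention(q, k, v, chunk_size=2):
    # Low-memory (chunked / online-softmax) attention sketch.
    # q: (q_len, d); k, v: (k_len, d).
    d = q.shape[-1]
    q = q / np.sqrt(d)
    acc = np.zeros_like(q)                      # running weighted-value sum
    denom = np.zeros(q.shape[0])                # running softmax normalizer
    running_max = np.full(q.shape[0], -np.inf)  # running logit max
    for start in range(0, k.shape[0], chunk_size):
        k_c = k[start:start + chunk_size]
        v_c = v[start:start + chunk_size]
        logits = q @ k_c.T                      # (q_len, chunk) only
        new_max = np.maximum(running_max, logits.max(axis=-1))
        # Rescale previous partial sums to the new max before accumulating.
        correction = np.exp(running_max - new_max)
        p = np.exp(logits - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_c
        denom = denom * correction + p.sum(axis=-1)
        running_max = new_max
    return acc / denom[:, None]
```

This is exactly the rescaling trick FlashAttention builds on, which is why its numerics are closer to this variant than to standard attention.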
Sorry, I think I missed this comment. Do you mean something like:
```python
def sdpa(..., implementation=None):
    if implementation == 'cudnn':
        cudnn_sdpa()  # users expect to fail on error
    elif implementation == 'pallas':
        pallas_sdpa()  # this is for the future
    elif implementation is None:
        # current try-except path; always falls back to `_dot_product_attention_xla`
        ...
```
Re cuDNN flash attention:
(1) cuDNN used to expose both flash and non-flash attention kernels, but we chose not to use the non-flash one anymore. So "cudnn attention" now means cuDNN flash attention, and I am OK with the name `"cudnn"`.
(2) We don't need to pass a config to the cuDNN calls, and we are trying to hide it from users.
> Sorry, I think I missed this comment. Do you mean something like:
That looks correct. We have two options here:
- Have multiple SDPA functions, one per backend/implementation.
- Have a single API with the `implementation` option.
There are pros and cons of each, and some tricky questions. For example:
- How closely do numerics need to match in the super-function to be considered 'the same'? As found in this review, cuDNN with bf16 inputs does not cast the first matmul to bf16 before doing softmax, whilst XLA does. If we choose the cuDNN convention, the XLA implementation will be incredibly memory-inefficient. This might be a significant difference in certain applications (e.g. training with one but doing inference with the other on a different device type). With future Pallas kernels, we can match the numerics, but this might be harder for third-party libraries like cuDNN. We might also do autotuning and choose the best kernel with the `None` option, which becomes problematic with these numerical differences. This is an argument to have separate functions for third-party kernels that JAX has no control over and that are largely opaque (it is hard to see what numerical choices are being made), and to only have a super-function for implementations under JAX control.
- Another argument for separate functions is that the API can be restricted to only the supported features, rather than being the most general function imaginable. The current design makes it hard for users to see what is supported, and limits documentation opportunities. In addition, there are cuDNN-specific options (like the Philox dropout) unsupported by any other backend, further complicating the API.
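The precision question above can be demonstrated directly: compare softmax over full-precision logits with softmax over logits downcast to the input dtype first. A NumPy sketch (float16 stands in for bf16 here, since NumPy has no bfloat16 dtype; the variable names are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    m = x - x.max(axis=axis, keepdims=True)
    e = np.exp(m)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 64)).astype(np.float16)
k = rng.normal(size=(8, 64)).astype(np.float16)

# QK^T accumulated in float32, as matmul units typically do.
logits_f32 = q.astype(np.float32) @ k.astype(np.float32).T
weights_keep = softmax(logits_f32)                        # softmax on f32 logits
weights_cast = softmax(logits_f32.astype(np.float16))     # downcast first
max_diff = np.abs(weights_keep - weights_cast).max()      # typically small but nonzero
```

The two conventions produce slightly different attention weights, which is exactly why silently autotuning between backends with different casting behavior is problematic.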
I think the name should be `dot_product_attention` rather than `scaled_dot_product_attention`. It's also more consistent with Flax naming (https://flax.readthedocs.io/en/v0.8.0/api_reference/flax.linen/_autosummary/flax.linen.dot_product_attention.html).
As discussed offline: let's land the simplest version first, without dropout or other complications. Then progressively add features.
Just pushed some new commits for the simplified SDPA. @sbodenstein PTAL.
Also talked to @Cjkkkk, and he will try to implement the combination of bias and mask in the cuDNN dot_product_attention API (as described in (1) above). Once that is in, our logic for preparing the bias will be much simpler.
Pushed a few more changes. PTAL. @sbodenstein
Please squash the commits. This will be mergeable as soon as Chris clarifies his comments.
Sure. Rebased. PTAL. @superbobry
@superbobry: I'm happy with the state of it now. Think we can merge.
Pushed new commits to resolve some failing Python lint tests. Btw, can we have access to add the kokoro:force-run label to trigger the tests?
Please squash the commits and we can merge.
Done. PTAL. @superbobry
I still see this lint error: `jax/_src/nn/functions.py:924: error: Argument 4 to "dot_product_attention" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | int | float | complex | None"; expected "Array | None" [arg-type]`. But I am a bit confused: I think it refers to the mask, which I have already converted to an Array via `jnp.asarray(mask)` at the beginning of the function. Do you have any advice on this? @superbobry @sbodenstein
No worries, I'll resolve this internally.
> As discussed offline: let's land the simplest version first, without dropout or other complications. Then progressively add features.
Thanks for adding FA! Is there a timeline for adding dropout support in the SDPA API? I understand it is on hold due to differences in PRNG implementation. Would it be OK if we exposed `dropout_rate` in the API while warning the user about reproducibility if `cudnn` is selected?
https://github.com/google/jax/blob/417fcd574b9f33410ea8eb78ffdea825ad343eee/jax/_src/cudnn/fused_attention_stablehlo.py#L954-L956
Yes, this is on our radar to be implemented. Can we know what types of models you are working on that need dropout?
> Yes, this is on our radar to be implemented. Can we know what types of models you are working on that need dropout?
Attention dropout would help for almost all low-data training regimes. Detection Transformers are one well-known example.
Torch supports FA dropout (possibly non-deterministic) in its functional API.
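For context on where a `dropout_rate` parameter would act, here is a NumPy sketch: dropout is applied to the post-softmax attention weights, before the value matmul, with the usual inverted scaling. The function name and signature are illustrative, not the proposed jax.nn API, and cuDNN's Philox-based dropout would not be bitwise-reproducible against a host-side RNG like this:

```python
import numpy as np

def attention_with_dropout(q, k, v, dropout_rate=0.0, rng=None):
    # Scaled dot-product attention with (inverted) dropout on the weights.
    # q: (q_len, d); k, v: (k_len, d).
    d = q.shape[-1]
    logits = q @ k.T / np.sqrt(d)
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    if dropout_rate > 0.0:
        rng = rng or np.random.default_rng()
        keep = rng.random(weights.shape) >= dropout_rate
        # Inverted dropout: surviving weights are rescaled by 1/(1 - rate).
        weights = np.where(keep, weights / (1.0 - dropout_rate), 0.0)
    return weights @ v
```

With `dropout_rate=0.0` this reduces to plain attention, which is why the parameter can be added to the API without changing existing behavior.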