
[NVIDIA] Add new SDPA API to jax.nn

Open kaixih opened this issue 1 year ago • 10 comments

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow; examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the FlashAttention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports FlashAttention and, through its API, can result in a 1.3x end-to-end speedup for training GPT-based large language models.

This PR proposes introducing a new API in the jax.nn module to handle attention. It will first try the cuDNN flash-attention execution path when the configuration is compatible; otherwise, it falls back to a JAX implementation.
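
For context, the computation such an API would cover looks roughly like the plain-JAX sketch below (the function name, argument list, and the [batch, seq, heads, head_dim] layout are illustrative assumptions, not the proposed signature):

import jax
import jax.numpy as jnp

def reference_dot_product_attention(query, key, value, bias=None, mask=None):
  # query: [batch, q_len, heads, head_dim]; key/value: [batch, kv_len, heads, head_dim].
  scale = query.shape[-1] ** -0.5
  logits = jnp.einsum('...qhd,...khd->...hqk', query, key) * scale
  if bias is not None:
    logits = logits + bias
  if mask is not None:
    # mask is assumed boolean: True = attend, False = exclude from the softmax.
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('...hqk,...khd->...qhd', weights, value)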

cc. @nluehr @Cjkkkk @cliffwoolley

kaixih avatar May 22 '24 20:05 kaixih

@hawkinsp Can you help find reviewers?

kaixih avatar May 22 '24 20:05 kaixih

Pushed a new commit to remove the use of is_training for the cudnn flash attention. This is a followup of this merged PR.

kaixih avatar May 31 '24 18:05 kaixih

@sharadmv Any updates?

kaixih avatar Jun 04 '24 16:06 kaixih

The API should have an implementation option, taking values like "xla", "cudnn", and None (the default, which selects the best algorithm). This list will grow with alternative kernel implementations (Pallas, etc). It is important to be able to select the implementation type:

  • "cudnn" will fail immediately if there is some unsupported shape, which prevents silent reversions to slow code paths.
  • Generating serialized models to do inference with on a different device type (eg train on GPU and test on TPU).

Regarding the names: does cuDNN expose both FlashAttention and non-FlashAttention? Perhaps this should be "cudnn_flash"? Note that XLA also has different implementations: we could support the low-memory chunked implementation given here (https://arxiv.org/abs/2112.05682), which inspired FlashAttention, is numerically closer to FlashAttention than standard attention, and has the same memory complexity (maybe "xla_chunked"? "xla_low_memory"?).
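
For illustration, a minimal unbatched sketch of that chunked (online-softmax) approach, assuming a single head, no mask/bias, and a key length divisible by the chunk size:

import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, chunk_size=128):
  # q: [q_len, head_dim]; k, v: [kv_len, head_dim].
  scale = q.shape[-1] ** -0.5
  q = q * scale
  num_chunks = k.shape[0] // chunk_size
  k_chunks = k.reshape(num_chunks, chunk_size, -1)
  v_chunks = v.reshape(num_chunks, chunk_size, -1)

  def step(carry, kv):
    acc, denom, running_max = carry
    k_c, v_c = kv
    scores = q @ k_c.T                                               # [q_len, chunk_size]
    new_max = jnp.maximum(running_max, scores.max(axis=-1, keepdims=True))
    correction = jnp.exp(running_max - new_max)                      # rescale previous partial sums
    p = jnp.exp(scores - new_max)
    acc = acc * correction + p @ v_c
    denom = denom * correction + p.sum(axis=-1, keepdims=True)
    return (acc, denom, new_max), None

  init = (jnp.zeros((q.shape[0], v.shape[-1]), q.dtype),
          jnp.zeros((q.shape[0], 1), q.dtype),
          jnp.full((q.shape[0], 1), -jnp.inf, q.dtype))
  (acc, denom, _), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
  return acc / denom

Up to floating-point error this matches standard attention, but it only ever materializes [q_len, chunk_size] blocks of the score matrix.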

Are there any configuration options a user might want to pass to the cuDNN implementation? If so, it could be a string or a cuDNN config dataclass. For example, in the low-memory XLA case, the chunk size is something a user might want to configure.

sbodenstein avatar Jun 07 '24 14:06 sbodenstein

Sorry, I think I missed this comment. Do you mean something like:

def sdpa(..., implementation=None):
  if implementation == 'cudnn':
    return cudnn_sdpa(...)   # users expect this to fail hard on unsupported configs
  elif implementation == 'pallas':
    return pallas_sdpa(...)  # placeholder for a future Pallas kernel
  elif implementation is None:
    # Current try-except path: try cudnn first, and always fall back to
    # `_dot_product_attention_xla` on failure.
    ...

Re cuDNN flash attention: (1) cuDNN used to expose both flash and non-flash attention kernels, but we have chosen not to use the non-flash one anymore, so "cudnn" now means the cuDNN flash attention; I am OK with the name "cudnn". (2) We don't need to pass any config to the cuDNN calls, and we are trying to hide it from users.

kaixih avatar Jun 12 '24 17:06 kaixih

Sorry, I think I missed this comment. Do you mean something like:

That looks correct. We have two options here:

  1. Have multiple SDPA functions, one per backend/implementation.
  2. Have a single API with the implementation option.

There are pros and cons of each, and some tricky questions. For example:

  • How closely do numerics need to match in the super-function to be considered 'the same'? As found in this review, cuDNN with bf16 inputs does not cast the output of the first matmul to bf16 before doing the softmax, whilst XLA does (see the short sketch after this list). If we choose the cuDNN convention, the XLA implementation will be incredibly memory-inefficient. This might be a significant difference in certain applications (e.g. training with one implementation but doing inference with the other on a different device type). With future Pallas kernels, we can match the numerics, but this might be harder for third-party libraries like cuDNN. We might also do autotuning and choose the best kernel with the None option, which becomes problematic with these numerical differences. This is an argument for having separate functions for third-party kernels that JAX has no control over and which are largely opaque (it is hard to see what numerical choices are being made), and only having a super-function for implementations under JAX's control.
  • Another argument for separate functions is that each API can be restricted to only the supported features, rather than being the most general function imaginable. The current design makes it hard for users to see what is supported, and limits documentation opportunities. In addition, there are cuDNN-specific options (like the Philox dropout) unsupported by any other backend, further complicating the API.
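
A rough illustration of the two conventions mentioned in the first bullet (names are placeholders, not the actual kernels):

import jax.numpy as jnp

def logits_keep_input_dtype(q, k):
  # XLA-reference-style per the observation above: the softmax sees logits in
  # the input dtype (e.g. bf16).
  return jnp.einsum('...qd,...kd->...qk', q, k)

def logits_keep_f32(q, k):
  # cuDNN-style per the observation above: the softmax sees logits kept in f32.
  return jnp.einsum('...qd,...kd->...qk', q, k, preferred_element_type=jnp.float32)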

sbodenstein avatar Jun 13 '24 14:06 sbodenstein

I think the name should be dot_product_attention rather than scaled_dot_product_attention. It's also more consistent with Flax naming (https://flax.readthedocs.io/en/v0.8.0/api_reference/flax.linen/_autosummary/flax.linen.dot_product_attention.html).

sbodenstein avatar Jun 14 '24 12:06 sbodenstein

As discussed offline: let's land the simplest version first, without dropout or other complications, then progressively add features.

sbodenstein avatar Jun 24 '24 17:06 sbodenstein

Just pushed some new commits for the simplified sdpa. @sbodenstein PTAL.

Also talked to @Cjkkkk, and he will try to implement the combination of bias and mask in the cuDNN dot_product_attention API (as described here in (1)). Once that is in, our bias-preparation logic will be much simpler.
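
For reference, one common way to fold a boolean mask into the additive bias on the JAX side is sketched below; this illustrates the general technique, not necessarily what the cuDNN integration will do:

import jax.numpy as jnp

def combine_bias_and_mask(bias, mask, dtype=jnp.float32):
  # mask is boolean: True = attend, False = exclude from the softmax.
  masked = jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)
  return masked if bias is None else bias.astype(dtype) + masked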

kaixih avatar Jun 25 '24 00:06 kaixih

Pushed a few more changes. PTAL. @sbodenstein

kaixih avatar Jun 28 '24 21:06 kaixih

Please squash the commits. This will be mergeable as soon as Chris clarifies his comments.

superbobry avatar Jul 05 '24 08:07 superbobry

Sure. Rebased. PTAL. @superbobry

kaixih avatar Jul 05 '24 17:07 kaixih

@superbobry: I'm happy with the state of it now. Think we can merge.

sbodenstein avatar Jul 05 '24 17:07 sbodenstein

Pushed new commits to resolve some failed Python lint tests. Btw, can we get access to add the kokoro:force-run label to trigger the tests?

kaixih avatar Jul 05 '24 19:07 kaixih

Please squash the commits and we can merge.

superbobry avatar Jul 07 '24 10:07 superbobry

Done. PTAL. @superbobry

kaixih avatar Jul 07 '24 16:07 kaixih

I still see this lint error: jax/_src/nn/functions.py:924: error: Argument 4 to "dot_product_attention" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | int | float | complex | None"; expected "Array | None" [arg-type]. But I am a bit confused: I think it refers to the mask, which I have already converted to an Array via jnp.asarray(mask) at the beginning of the function. Do you have any advice on this? @superbobry @sbodenstein

kaixih avatar Jul 07 '24 23:07 kaixih

No worries, I'll resolve this internally.

superbobry avatar Jul 08 '24 09:07 superbobry

As discussed offline: let's land the simplest version first, without dropout or other complications, then progressively add features.

Thanks for adding FA! Is there a timeline for adding dropout support to the SDPA API? I understand it is on hold due to differences in PRNG implementation. Would it be OK to expose dropout_rate in the API while warning the user about reproducibility when cudnn is selected?

https://github.com/google/jax/blob/417fcd574b9f33410ea8eb78ffdea825ad343eee/jax/_src/cudnn/fused_attention_stablehlo.py#L954-L956
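
For what it's worth, the pure-JAX fallback could apply attention-weight dropout along these lines (a sketch assuming the rate and PRNG key are exposed; the cuDNN kernel uses its own Philox-based RNG, so results would not be bit-identical across backends):

import jax
import jax.numpy as jnp

def dropout_attention_weights(weights, rate, key):
  # Inverted dropout on the post-softmax attention weights.
  keep = jax.random.bernoulli(key, 1.0 - rate, weights.shape)
  return jnp.where(keep, weights / (1.0 - rate), 0.0)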

MasterSkepticista avatar Aug 16 '24 06:08 MasterSkepticista

Yes, this is on our radar. May we ask what types of models you are working on that need dropout?

kaixih avatar Aug 27 '24 16:08 kaixih

Yes, this is on our radar. May we ask what types of models you are working on that need dropout?

Attention dropout would help for almost all low-data training regimes. Detection Transformers are one well-known example.

PyTorch supports FlashAttention dropout (possibly non-deterministic) in its functional API.
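
For reference, the PyTorch call being referred to looks roughly like this (dropout is applied inside the fused kernel; which backend is picked depends on the device and inputs):

import torch
import torch.nn.functional as F

# Shapes: [batch, heads, seq_len, head_dim].
q = k = v = torch.randn(2, 8, 128, 64)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)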

MasterSkepticista avatar Aug 28 '24 04:08 MasterSkepticista