Shanbin Ke
Requires replacing `async=True` with `non_blocking=True` in `utils/base.py`, since `async` is a reserved keyword as of Python 3.7.
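A minimal sketch of why the rename is needed, using only the standard library. The call strings below are hypothetical stand-ins (the actual line in `utils/base.py` is not shown here); `compile` only parses them, so no tensor library is required.

```python
import keyword


def parses(source: str) -> bool:
    """Return True if `source` is syntactically valid for this interpreter."""
    try:
        compile(source, "<example>", "exec")  # parse only, nothing is executed
        return True
    except SyntaxError:
        return False


# Hypothetical call sites, before and after the rename.
old_call = "tensor.cuda(async=True)"
new_call = "tensor.cuda(non_blocking=True)"

print("async is a keyword:", keyword.iskeyword("async"))
print("old call parses:", parses(old_call))
print("new call parses:", parses(new_call))
```

On Python 3.7 or later the old spelling fails at parse time, so the module cannot even be imported until the keyword argument is renamed.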
* add `is_training` to API to distinguish between inference and training fwd so cuDNN does not export activation.
* fix incorrect checks for seqlen/head_dim
* seqlen_k = 256 and seqlen_v...
Add an option to XLA to enforce inlining before LLVM `splitModule`, or to set `preserveLocals=False`, to get more balanced splits in the parallel compilation case. Some data from a GPT3 5B model with different...
* add variable sequence length: accepts two additional tensors, `seqlen_q` and `seqlen_kv`, indicating the non-padded length to reduce computation.
* add MQA/GQA.
* add broadcast bias: bias can...
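The two ideas named above can be illustrated in plain Python. This is a sketch under stated assumptions, not the cuDNN or JAX implementation: the helper names `kv_head_for` and `valid_positions` are hypothetical, the head grouping shown is the common contiguous-group convention for MQA/GQA, and a real kernel would use the lengths to skip work rather than build an explicit mask.

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the key/value head it shares (GQA).

    MQA is the special case num_kv_heads == 1.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def valid_positions(seqlen_q: int, seqlen_kv: int, max_q: int, max_kv: int):
    """Mark which (query, key) pairs fall inside the non-padded lengths."""
    return [
        [(i < seqlen_q) and (j < seqlen_kv) for j in range(max_kv)]
        for i in range(max_q)
    ]


# 8 query heads sharing 2 key/value heads: heads 0-3 use kv head 0, 4-7 use kv head 1.
print([kv_head_for(h, 8, 2) for h in range(8)])

# One sequence padded to 3 queries x 4 keys, with true lengths 2 and 3.
for row in valid_positions(2, 3, 3, 4):
    print(row)
```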
* cuDNN SDPA no longer supports a mask input, so we combine the bias and mask manually to align with the public SDPA API design.
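One common way to fold a boolean mask into an additive bias is sketched below in plain Python. The function names are hypothetical, and the convention that `True` means "attend" is an assumption; the actual change may use a large finite negative value instead of negative infinity.

```python
import math

NEG_INF = float("-inf")


def combine_mask_and_bias(bias, mask):
    """Return an additive bias that also encodes the mask.

    Assumes mask[i][j] is True where attention is allowed. Masked
    positions become -inf, so they get zero weight after softmax.
    """
    return [
        [b if m else NEG_INF for b, m in zip(bias_row, mask_row)]
        for bias_row, mask_row in zip(bias, mask)
    ]


def softmax(row):
    """Numerically stable softmax; assumes at least one unmasked entry."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [[0.0, 0.0], [0.0, 0.0]]   # stand-in for the Q @ K^T logits
bias = [[0.5, 0.25], [1.0, 2.0]]
mask = [[True, False], [True, True]]

combined = combine_mask_and_bias(bias, mask)
weights = [
    softmax([s + c for s, c in zip(score_row, comb_row)])
    for score_row, comb_row in zip(scores, combined)
]
print(weights)
```

A row whose entries are all masked would need separate handling, since every logit in it would be negative infinity.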
Support adding arbitrary pointwise operations between bmm1 and softmax to support new variants of cuDNN attention, for example softcap. An example of generated HLO for inserting soft_cap into cuDNN SDPA:...
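To show where such a pointwise operation sits, here is a plain-Python sketch that applies soft capping to the logits before softmax. It assumes the usual definition `cap * tanh(x / cap)`; the function names and the `cap` parameter are hypothetical and do not reflect the XLA or cuDNN interface.

```python
import math


def soft_cap(logits, cap):
    """Squash logits smoothly into [-cap, cap] with a scaled tanh."""
    return [cap * math.tanh(x / cap) for x in logits]


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(logits, cap=None):
    """Apply the optional pointwise op between the first matmul and softmax."""
    if cap is not None:
        logits = soft_cap(logits, cap)
    return softmax(logits)


logits = [100.0, -100.0, 0.0]        # stand-in for one row of Q @ K^T
capped = soft_cap(logits, 30.0)
print(capped)
print(attention_weights(logits, cap=30.0))
```

Because the operation is elementwise, it can be placed between the two stages without changing the shapes that softmax sees.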
Add JAX cuDNN SDPA API support in MaxText; this supports training, prefill, and decoding.