Shanbin Ke

Results: 13 issues by Shanbin Ke

Requires replacing `async=True` with `non_blocking=True` in `utils/base.py`, since `async` is a reserved keyword as of Python 3.7.
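The keyword clash can be checked with the standard library alone. This is a minimal sketch: the `batch.cuda(...)` call strings are hypothetical stand-ins for whatever `utils/base.py` actually contains, and they are only parsed, never executed.

```python
import keyword


def parses(source: str) -> bool:
    """Return True if `source` is syntactically valid on this interpreter."""
    try:
        compile(source, "<snippet>", "exec")  # parse only, nothing is run
        return True
    except SyntaxError:
        return False


# Hypothetical call sites; the real code in utils/base.py may differ.
old_call = "batch = batch.cuda(async=True)"
new_call = "batch = batch.cuda(non_blocking=True)"

async_is_reserved = keyword.iskeyword("async")
old_parses = parses(old_call)
new_parses = parses(new_call)

print(async_is_reserved, old_parses, new_parses)
```

On Python 3.7 or later the old spelling fails at parse time, so the file cannot even be imported until the argument is renamed.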

* Add `is_training` to the API to distinguish inference from the training forward pass, so that cuDNN does not export the activation during inference.
* Fix incorrect checks for seqlen/head_dim.
* seqlen_k = 256 and seqlen_v...

pull ready

Add an option to XLA to either enforce inlining before LLVM's `splitModule` or set `preserveLocals=False`, which gives more balanced splits when compiling in parallel. Some data for the GPT-3 5B model with different...

* Add variable sequence length: accepts two additional tensors, seqlen_q and seqlen_kv, that give the non-padded lengths so that computation on padding can be skipped.
* Add MQA/GQA.
* Add broadcast bias: bias can...
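A rough sketch of two of the ideas listed above, in plain Python rather than the actual JAX/cuDNN code. The helper names `padding_mask` and `kv_head_for` are hypothetical, and the mapping of consecutive query heads to one KV head is a common convention assumed here, not something the PR text confirms.

```python
def padding_mask(seqlen_q, seqlen_kv, max_q, max_kv):
    """True where both the query row and the key column are non-padded tokens."""
    return [
        [(i < seqlen_q) and (j < seqlen_kv) for j in range(max_kv)]
        for i in range(max_q)
    ]


def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Map a query head to the KV head it shares (assumed grouping scheme).

    GQA uses 1 < num_kv_heads < num_q_heads; MQA is the case num_kv_heads == 1.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# One sequence padded to 4x4 whose real lengths are 2 (queries) and 3 (keys).
mask = padding_mask(seqlen_q=2, seqlen_kv=3, max_q=4, max_kv=4)

gqa_map = [kv_head_for(h, num_q_heads=8, num_kv_heads=2) for h in range(8)]
mqa_map = [kv_head_for(h, num_q_heads=8, num_kv_heads=1) for h in range(8)]
```

Only the entries where `mask` is True need to be computed, which is where the saving from passing the real lengths comes from.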

pull ready

* cuDNN SDPA no longer supports a mask input, so we combine the bias and mask manually to align with the public SDPA API design.
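One common way to fold a boolean mask into an additive bias is sketched below in plain Python. The function name `combine_bias_and_mask` and the constant `LARGE_NEGATIVE` are assumptions for illustration; the PR's actual implementation is not shown in the text above.

```python
import math

LARGE_NEGATIVE = -1e30  # assumed stand-in for "minus infinity" in the logits


def combine_bias_and_mask(bias, mask):
    """Keep the bias where mask is True; elsewhere force a very negative logit."""
    return [
        [b if keep else LARGE_NEGATIVE for b, keep in zip(bias_row, mask_row)]
        for bias_row, mask_row in zip(bias, mask)
    ]


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


bias = [[0.5, -0.5, 0.0]]
mask = [[True, False, True]]  # True means "attend", False means "masked out"

combined = combine_bias_and_mask(bias, mask)
weights = softmax(combined[0])
```

After the softmax, the masked position receives zero weight, so a single bias input can carry both pieces of information.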

pull ready

Support inserting arbitrary pointwise operations between bmm1 and softmax to enable new variants of cuDNN attention, for example softcap. An example of the generated HLO for inserting soft_cap into cuDNN SDPA:...
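To illustrate what a pointwise operation between bmm1 and softmax looks like, here is a plain-Python sketch of softcap for a single row of scores. It assumes the commonly used definition `cap * tanh(score / cap)`; the function names and the cap value are hypothetical, and the real change operates on HLO rather than Python lists.

```python
import math


def soft_cap(score, cap):
    """Smoothly squash a score into the open interval (-cap, cap)."""
    return cap * math.tanh(score / cap)


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def capped_attention_weights(scores, cap):
    """Apply the pointwise op to the bmm1 output, then the usual softmax."""
    return softmax([soft_cap(s, cap) for s in scores])


scores = [100.0, 0.1, -250.0]  # one row of Q @ K^T scores (made-up values)
cap = 30.0                     # assumed cap value, chosen for illustration
capped = [soft_cap(s, cap) for s in scores]
weights = capped_attention_weights(scores, cap)
```

Small scores pass through almost unchanged while large ones saturate near the cap, which is why the operation can be inserted without touching the matmul or the softmax themselves.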

Add JAX cuDNN SDPA API support to MaxText; this supports training, prefill, and decoding. Checklist: before submitting this PR, please make sure (put X in square brackets): - [x] I have performed a...