Sebastian Bodenstein
This has been reported before on the OpenMM GitHub page, with the resolution being to upgrade the machine's NVIDIA driver: https://github.com/openmm/openmm/issues/3585 Does this solve your issue?
@nouiz: Any updates about which cuBLAS update will contain the fix?
@nouiz: did this end up being fixed by CUDA 12?
@Jokeren: just how bad is the accuracy?
`triton_call` is already wrapped as a primitive. Let's try to fix it at the Pallas/jax-triton level.
I think the explicit DCE solution should be fine. Will let you know if we hit any issues.
The API should have an `implementation` option, taking values like `"xla"`, `"cudnn"`, and `None` (the default, which selects the best algorithm). This list will grow with alternative kernel implementations (Pallas,...
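To make the proposal concrete, here is a hedged sketch of what such an API surface might look like. The function name, signature, and shape conventions are assumptions for illustration; only the pure-JAX ("xla") path is filled in, and `"cudnn"` would dispatch to a fused kernel in a real implementation.

```python
import jax
import jax.numpy as jnp


def dot_product_attention(q, k, v, *, implementation=None):
    """Sketch of the proposed API (hypothetical signature).

    q, k, v: arrays of shape (..., seq, heads, head_dim).
    implementation: "xla", "cudnn", or None (pick the best available).
    """
    if implementation not in (None, "xla"):
        # In a real implementation, "cudnn" etc. would dispatch to a
        # fused attention kernel here.
        raise NotImplementedError(f"implementation={implementation!r}")
    scale = q.shape[-1] ** -0.5
    # (..., q_seq, heads, d) x (..., k_seq, heads, d) -> (..., heads, q_seq, k_seq)
    logits = jnp.einsum("...qhd,...khd->...hqk", q * scale, k)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum over the key sequence axis.
    return jnp.einsum("...hqk,...khd->...qhd", weights, v)
```

The `None` default keeps the dispatch logic out of user code, while an explicit `implementation` lets users pin a backend when they need reproducibility or want to benchmark kernels against each other.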
> Sorry, I think I missed this comment. Do you mean sth like: That looks correct. We have two options here: 1. Have multiple SDPA functions, one per backend/implementation. 2....
I think the name should be `dot_product_attention` rather than `scaled_dot_product_attention`. It's also more consistent with Flax naming (https://flax.readthedocs.io/en/v0.8.0/api_reference/flax.linen/_autosummary/flax.linen.dot_product_attention.html).
As discussed offline: let's land the simplest version first, without dropout or other complications, then progressively add features.