oliver
@sokrypton Of course, I've opened a pull request in ColabDesign with it (https://github.com/sokrypton/ColabDesign/pull/173).
Prior to https://github.com/google-deepmind/alphafold/pull/931/commits/d4516d83aaf65aee2e2c90ca85b86acacd464c0f, I see transient NaN behaviour on shapes that don't evenly divide the block size (i.e. out-of-bounds loading). [gist](https://gist.github.com/oliverdutton/98c468dccfc5dcc3f0f2c0b793f46bb2) to reproduce the problem: ```python import jax from jax import jit, numpy...
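A minimal NumPy sketch of why the non-evenly-divisible case goes wrong, assuming a 1-D array summed in fixed-size blocks. `blocked_sum` and `mask_oob` are hypothetical names, and NaN is only a stand-in for whatever undefined values sit past the end of the buffer (on real hardware those values vary, which is why the NaNs are transient). With a grid of ceil(n / block_size) blocks the last block overhangs the array, and unless those lanes are masked the garbage leaks into the result:

```python
import numpy as np

def blocked_sum(x: np.ndarray, block_size: int, mask_oob: bool = True) -> float:
    """Sum a 1-D array block by block, the way a blocked kernel would."""
    n = x.shape[0]
    num_blocks = -(-n // block_size)  # ceiling division: the grid covers the tail
    # Stand-in for undefined memory past the end of the buffer. Real hardware
    # may return anything here, so NaN only shows up some of the time.
    buffer = np.full(num_blocks * block_size, np.nan)
    buffer[:n] = x
    total = 0.0
    for i in range(num_blocks):
        idx = i * block_size + np.arange(block_size)
        block = buffer[idx]
        if mask_oob:
            # Keep only lanes whose global index lies inside the array.
            block = np.where(idx < n, block, 0.0)
        total += float(block.sum())
    return total

x = np.arange(10, dtype=np.float64)
print(blocked_sum(x, 4))                  # masked: tail lanes ignored
print(blocked_sum(x, 4, mask_oob=False))  # unmasked: tail garbage leaks in
```

When the block size divides the length evenly (e.g. 5 for a length-10 array) no block overhangs, so masking makes no difference, which matches the problem only appearing for non-divisible shapes.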
@sokrypton I think this is ready for merging. It's still strictly opt-in (as Pallas with Triton is only available on Ampere-architecture GPUs and newer). You could improve performance a...
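A minimal sketch of what such an opt-in gate could look like, assuming the decision is made from a user flag plus the GPU's CUDA compute capability (Ampere corresponds to compute capability 8.x). `should_use_pallas` and `parse_compute_capability` are hypothetical helpers, not the PR's actual code:

```python
from typing import Optional, Tuple

def parse_compute_capability(cc: str) -> Tuple[int, int]:
    # "8.6" -> (8, 6); tuples compare lexicographically, major version first.
    major, minor = cc.split(".")
    return int(major), int(minor)

def should_use_pallas(opt_in: bool, compute_capability: Optional[str]) -> bool:
    # Hypothetical gate: take the Pallas/Triton path only when the user has
    # explicitly opted in AND the GPU is Ampere (compute capability 8.0) or newer.
    # None means "no CUDA GPU", in which case the default path is used.
    if not opt_in or compute_capability is None:
        return False
    return parse_compute_capability(compute_capability) >= (8, 0)
```

On a CUDA backend, recent jaxlib versions expose the capability string on the device object as `jax.devices()[0].compute_capability`; on CPU or TPU this sketch expects `None`. Keeping the flag default-off means older GPUs never hit the Triton path unless someone asks for it.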
This corrects issues similar to those in https://github.com/google/jax/pull/21180 (@justinjfu), though here they relate to indexing into MemRefs rather than to block shapes that don't evenly divide the arrays being chunked.
This PR is now solely aimed at fixing the dynamic-slice store/load operations in the interpreter identified in https://github.com/google/jax/issues/21143. In the interpreter, it pads the arrays with uninitialised values so dynamic...
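A rough NumPy illustration of why padding helps, assuming the interpreter lowers loads and stores to `jax.lax.dynamic_slice` / `jax.lax.dynamic_update_slice`, which clamp their start indices so the window always fits inside the operand. That clamping means an out-of-range access silently shifts onto in-bounds elements instead of reading or writing past the end. The three helpers below are hypothetical 1-D stand-ins, with NaN playing the role of uninitialised padding:

```python
import numpy as np

def clamped_dynamic_slice(x, start, size):
    # Mimics the 1-D clamping rule of jax.lax.dynamic_slice: the start index
    # is clamped so the whole window fits, which silently shifts the window.
    start = min(max(start, 0), x.shape[0] - size)
    return x[start:start + size]

def padded_dynamic_slice(x, start, size, fill=np.nan):
    # Pad the end so no clamping is needed; lanes past the end of x come back
    # as `fill` (NaN stands in for uninitialised memory; float dtype assumed).
    pad = max(0, start + size - x.shape[0])
    padded = np.concatenate([x, np.full(pad, fill, dtype=x.dtype)])
    return padded[start:start + size]

def padded_dynamic_store(x, start, values):
    # Write into a padded copy, then drop the padding, so out-of-bounds lanes
    # are discarded instead of landing on shifted in-bounds elements.
    n, size = x.shape[0], values.shape[0]
    pad = max(0, start + size - n)
    padded = np.concatenate([x, np.zeros(pad, dtype=x.dtype)])
    padded[start:start + size] = values
    return padded[:n]

x = np.arange(6, dtype=np.float64)
print(clamped_dynamic_slice(x, 4, 4))  # window shifted back to start at 2
print(padded_dynamic_slice(x, 4, 4))
print(padded_dynamic_store(np.zeros(6), 4, np.ones(4)))
```

The clamped load returns elements 2 to 5 rather than 4 onwards, whereas the padded version returns elements 4 and 5 followed by padding; likewise the padded store only touches the last two entries instead of overwriting a shifted window.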
Review comments addressed; this should now be ready.