jax
[pallas] align interpreter load/store with masked behaviour (and add stride support)
Matches the behaviour in Triton, where masked-out indices in load/store/swap are never accessed.
For load, masked-out indices are redirected to the first element of the array. For swap (and store), masked-out indices are likewise redirected to the first element, with extra handling to make sure the first element still ends up with the correct value.
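A minimal sketch of the masked load and swap semantics described above, written with plain jax.numpy on 1-D arrays rather than the actual Pallas interpreter code. The helpers masked_load_sketch and masked_swap_sketch are hypothetical. The swap assumes the unmasked indices are unique and in bounds, and reporting 0 as the "old value" for masked-out lanes is an arbitrary choice made here.

```python
import jax.numpy as jnp


def masked_load_sketch(x, idx, mask, other=0):
    """Hypothetical masked load: masked-out lanes only ever read x[0]."""
    # Materialise indices, redirecting masked-out lanes to element 0.
    safe_idx = jnp.where(mask, idx, 0)
    # Gather with the materialised indices.
    vals = x[safe_idx]
    # Masked-out lanes get the fill value instead of x[0].
    return jnp.where(mask, vals, other)


def masked_swap_sketch(x, idx, val, mask):
    """Hypothetical masked swap on a 1-D array (store = swap minus `old`).

    Assumes unmasked indices are unique. Returns (new_x, old_values);
    masked-out lanes report 0 as their old value (arbitrary choice).
    """
    safe_idx = jnp.where(mask, idx, 0)
    old = jnp.where(mask, x[safe_idx], 0)
    x0 = x[0]
    # Masked-out lanes scatter into element 0, writing back its current value.
    new_x = x.at[safe_idx].set(jnp.where(mask, val, x0))
    # If an unmasked lane also targets element 0, scatter order between it
    # and the masked-out lanes is unspecified, so fix element 0 explicitly.
    hits0 = (mask & (safe_idx == 0)).astype(jnp.int32)
    val0 = jnp.where(hits0.any(), val[jnp.argmax(hits0)], x0)
    new_x = new_x.at[0].set(val0)
    return new_x, old
```

For example, with x = jnp.arange(5.0), loading indices [1, 4, 99] under mask [True, True, False] returns [1., 4., 0.] without ever touching index 99.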
The dynamic_slices currently used are replaced with explicit index materialisation and gathers/scatters.
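To illustrate why explicit index materialisation also makes strides straightforward, here is a 1-D sketch using plain jax.numpy; strided_load_sketch is a hypothetical helper, not the interpreter's code. For stride 1 it matches lax.dynamic_slice, while stride 2 has no dynamic_slice equivalent. Note also that lax.dynamic_slice clamps its start index to keep the slice in bounds, so an out-of-range start is silently shifted rather than surfaced.

```python
import jax.numpy as jnp
from jax import lax


def strided_load_sketch(x, start, size, stride=1):
    """Hypothetical 1-D strided load via explicit index materialisation."""
    # Materialise start, start + stride, ..., start + (size - 1) * stride.
    idx = start + stride * jnp.arange(size)
    # Gather with the explicit index array.
    return x[idx]


x = jnp.arange(10)
unit = strided_load_sketch(x, 2, 3)                   # stride 1
ref = lax.dynamic_slice(x, (2,), (3,))                # same elements as `unit`
every_other = strided_load_sketch(x, 1, 4, stride=2)  # elements 1, 3, 5, 7
clamped = lax.dynamic_slice(x, (8,), (3,))            # start clamped from 8 to 7
```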
The advantage of this approach is that it can be combined with checkify(f, errors=checkify.index_checks) in interpreter mode to detect any unmasked out-of-bounds (OOB) indexing, which I believe is (and should be) undefined behaviour.
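A sketch of the checkify interaction, assuming a plain jax.numpy stand-in for an interpreted kernel rather than an actual pallas_call; masked_gather is a hypothetical helper. Because masked-out lanes are redirected to index 0 before the gather, an OOB index under a False mask does not trip the index check, while the same index with the mask set does.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def masked_gather(x, idx, mask):
    """Hypothetical masked load: redirect masked-out lanes to index 0."""
    safe_idx = jnp.where(mask, idx, 0)
    return jnp.where(mask, x[safe_idx], 0)


# Functionalise OOB index errors so they are returned instead of ignored.
checked = checkify.checkify(masked_gather, errors=checkify.index_checks)

x = jnp.arange(4.0)
idx = jnp.array([1, 3, 10])  # index 10 is out of bounds for shape (4,)

# Masked OOB lane: redirected to index 0, so no error is recorded.
err_masked, out_masked = checked(x, idx, jnp.array([True, True, False]))
# Unmasked OOB lane: the gather sees index 10 and checkify records an error.
err_unmasked, _ = checked(x, idx, jnp.array([True, True, True]))

print(err_masked.get())    # None
print(err_unmasked.get())  # an out-of-bounds indexing message
```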