
[pallas] align interpreter load/store with masked behaviour (and add stride support)

Open oliverdutton opened this issue 1 year ago • 0 comments

This matches Triton's behaviour, where load/store/swap perform no memory access at masked-out indices.

For load, any masked-out index is redirected to the first element of the array. For swap (and store), masked-out indices are likewise redirected to the first element, with special handling to make sure the first element itself ends up with the correct value.
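A minimal 1-D sketch of the redirection idea in plain JAX, assuming flat integer indices and a boolean mask of the same shape. `masked_load`, `masked_store` and `masked_swap` are hypothetical helpers for illustration, not the Pallas interpreter's actual code. Masked lanes gather from (or write back to) element 0, and element 0 is then fixed up explicitly, because the order of duplicate writes in a JAX scatter is unspecified.

```python
import jax.numpy as jnp

def masked_load(x, idx, mask, other=0):
    # Redirect masked-out lanes to element 0 so the gather never reads OOB.
    safe_idx = jnp.where(mask, idx, 0)
    return jnp.where(mask, x[safe_idx], other)

def masked_store(x, idx, val, mask):
    safe_idx = jnp.where(mask, idx, 0)
    # Masked-out lanes write the current x[0] back onto element 0.
    vals = jnp.where(mask, val, x[0])
    out = x.at[safe_idx].set(vals)
    # Duplicate writes to index 0 have unspecified order, so set element 0
    # explicitly: the first unmasked lane targeting 0 wins, else keep x[0].
    hits0 = mask & (idx == 0)
    first_hit = jnp.argmax(hits0.astype(jnp.int32))
    new0 = jnp.where(jnp.any(hits0), val[first_hit], x[0])
    return out.at[0].set(new0)

def masked_swap(x, idx, val, mask, other=0):
    # Swap = read the old values, then store the new ones.
    return masked_store(x, idx, val, mask), masked_load(x, idx, mask, other)

x = jnp.arange(4.0)
new_x, old = masked_swap(x, jnp.array([0, 5, 1]), jnp.array([7.0, 8.0, 9.0]),
                         jnp.array([True, False, True]))
```

Here the masked lane (index 5) neither reads nor writes out of bounds, and `new_x` is `[7, 9, 2, 3]`.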

The `dynamic_slice`s currently used are replaced with explicit index materialisation plus gathers/scatters.
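A sketch of what explicit index materialisation can look like for a 2-D block with optional strides, assuming static block shapes and strides. `materialise_indices` and `strided_block_load` are hypothetical names, not Pallas APIs. Each dimension gets an explicit index vector, and broadcasting those vectors yields the full index grid consumed by a single gather.

```python
import jax.numpy as jnp
from jax import lax

def materialise_indices(start, size, stride=1):
    # Explicit indices start, start + stride, ..., one per block element.
    return start + stride * jnp.arange(size)

def strided_block_load(x, row_start, col_start, shape, strides=(1, 1)):
    rows = materialise_indices(row_start, shape[0], strides[0])
    cols = materialise_indices(col_start, shape[1], strides[1])
    # Broadcast (rows, 1) against (1, cols) to index the whole block at once.
    return x[rows[:, None], cols[None, :]]

x = jnp.arange(36).reshape(6, 6)
block = strided_block_load(x, 1, 0, (2, 3), strides=(2, 2))
# With unit strides this agrees with the dynamic_slice it replaces.
same = strided_block_load(x, 1, 2, (2, 3))
ref = lax.dynamic_slice(x, (1, 2), (2, 3))
```

One design consequence: `lax.dynamic_slice` silently clamps its start indices so the slice stays in bounds, whereas materialised indices keep each element's address visible, which is what makes per-element masking (and OOB checking) possible.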

The advantage of this approach is that it can be combined with `checkify(f, errors=checkify.index_checks)` in interpreter mode to detect any unmasked out-of-bounds (OOB) indexing, which is (I think, and believe should be) undefined behaviour.
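A small sketch of why this pairs well with `checkify`, using a plain JAX stand-in for a masked load rather than a real Pallas kernel (`masked_load` is a hypothetical helper). Because masked-out lanes are redirected to index 0, only unmasked OOB indices reach the gather with a bad index, so only they are reported.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def masked_load(x, idx, mask, other=0):
    # Masked-out lanes read element 0 instead of their (possibly OOB) index.
    safe_idx = jnp.where(mask, idx, 0)
    return jnp.where(mask, x[safe_idx], other)

checked_load = checkify.checkify(masked_load, errors=checkify.index_checks)

x = jnp.arange(4.0)
idx = jnp.array([1, 7])
# Index 7 is masked out, so the gather never sees it: no error.
err_ok, _ = checked_load(x, idx, jnp.array([True, False]))
# Index 7 is unmasked, so the OOB gather is flagged.
err_bad, _ = checked_load(x, idx, jnp.array([True, True]))
```

`err_ok.get()` returns `None`, while `err_bad.get()` returns an out-of-bounds message (and `err_bad.throw()` would raise).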

oliverdutton, May 09 '24 09:05