Add split and key methods to RngStream and Rngs

Open samanklesaria opened this issue 6 months ago • 3 comments

Closes #5046

samanklesaria avatar Oct 23 '25 14:10 samanklesaria

Unfortunately, a key attribute already exists on streams. I have renamed the existing attribute to key_ and updated its uses with the handy sed command sed -I '' "s/\.key\([^_(]\)/.key_\1/".
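For readers who want to see what that substitution does, here is a rough Python equivalent of the sed expression. The helper name rename_key_attr and the sample lines are made up for illustration; only the pattern itself comes from the command above.

```python
import re

# Python equivalent of the sed expression s/\.key\([^_(]\)/.key_\1/
# ".key" followed by any character other than "_" or "(" becomes ".key_".
PATTERN = re.compile(r"\.key([^_(])")


def rename_key_attr(line: str) -> str:
    """Hypothetical helper: rewrite one line the way the sed command would."""
    # sed without the g flag rewrites only the first match on each line.
    return PATTERN.sub(r".key_\1", line, count=1)


# Illustrative inputs, not lines taken from the Flax code base.
samples = [
    "stream.key.value",   # attribute access: renamed
    "rngs.key()",         # method call: left alone because of "("
    "stream.key_.value",  # already renamed: left alone because of "_"
    "return self.key",    # nothing follows ".key": left alone
    "d.keys()",           # caveat: ".keys" also matches
]

for sample in samples:
    print(f"{sample!r} -> {rename_key_attr(sample)!r}")
```

Two things are worth noting about the pattern as written: it requires a character after .key, so an attribute access at the very end of a line is not renamed, and it does match longer names such as .keys, so the result of a bulk rename like this probably deserves a manual look.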

samanklesaria avatar Oct 27 '25 19:10 samanklesaria

Thanks @samanklesaria! I'll have to run all internal tests before merging.

As a follow-up PR, we could consider:

  • replacing all uses of jax.random.<function>(rngs(), ...) with rngs.<function>(...). I've done some of this.
  • when we do need a key, favoring rngs.key() or rngs.some_stream.key().
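To make the two call styles concrete, here is a minimal sketch. SketchRngs is a hypothetical stand-in written only for this illustration; it is not the Flax implementation, and the way it derives keys (fold_in over a counter) is an assumption. Only the jax.random calls are real library functions.

```python
import jax
import jax.numpy as jnp


class SketchRngs:
    """Hypothetical stand-in for an Rngs-like object (not the Flax class)."""

    def __init__(self, seed: int):
        self._base = jax.random.PRNGKey(seed)
        self._count = 0

    def key(self):
        # Assumption for this sketch: derive a fresh key from a counter.
        new_key = jax.random.fold_in(self._base, self._count)
        self._count += 1
        return new_key

    # rngs() and rngs.key() hand out keys the same way in this sketch.
    __call__ = key

    def normal(self, shape):
        # Sampling method that draws its own key internally.
        return jax.random.normal(self.key(), shape)


a = SketchRngs(0)
b = SketchRngs(0)

# Current style: request a key, then pass it to jax.random.
old_style = jax.random.normal(a(), (2, 3))

# Proposed style: call the sampling method on the rngs object directly.
new_style = b.normal((2, 3))

print(jnp.array_equal(old_style, new_style))
```

In this sketch both styles consume one key and produce the same sample, so the change is about call-site readability rather than behavior; whether that holds exactly for the real classes depends on how they derive keys.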

cgarciae avatar Oct 30 '25 21:10 cgarciae