flax
Add split and key methods to RngStream and Rngs
Closes #5046
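To illustrate what `key()` and `split()` methods on a stream might do, here is a minimal sketch. It is not Flax's `RngStream` implementation. `SketchRngStream` is a hypothetical class that assumes a stream holds a base key and a counter, and it stores the base key under `key_` so that the name `key` is free for the method. It also assumes `split(num)` draws one fresh key and splits it into `num` keys.

```python
import jax
import jax.numpy as jnp


class SketchRngStream:
    """Hypothetical stand-in for an RNG stream (not Flax's RngStream)."""

    def __init__(self, seed: int):
        self.key_ = jax.random.key(seed)  # base key; `key` is left free for the method
        self.count = 0                    # number of keys drawn so far

    def key(self) -> jax.Array:
        # Derive a fresh key by folding the running counter into the base key.
        new_key = jax.random.fold_in(self.key_, self.count)
        self.count += 1
        return new_key

    def split(self, num: int) -> jax.Array:
        # Assumption: draw one fresh key, then split it into `num` keys.
        return jax.random.split(self.key(), num)


stream = SketchRngStream(0)
k1 = stream.key()
k2 = stream.key()
keys = stream.split(4)
print(keys.shape)  # (4,)
```

Because each call to `key()` advances the counter, repeated calls return distinct keys without the caller having to thread a key through by hand.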
Unfortunately, a `key` attribute already exists on streams. I have renamed that attribute to `key_` and updated its uses with this hand-written sed command: `sed -I '' "s/\.key\([^_(]\)/.key_\1/"`.
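For reviewers who want to check which lines the substitution touches, here is a Python equivalent of the sed expression, applied to a few sample lines. The function name `rename_key_attr` and the samples are illustrative only. The pattern rewrites `.key` only when it is followed by a character other than `_` or `(`. That character is kept through the capture group, so method calls like `.key()` and already-renamed `.key_` are left alone. The sed expression has no `g` flag, so only the first match on each line is rewritten, and `count=1` mirrors that behavior.

```python
import re

# Same pattern as the sed expression: ".key" followed by any character other
# than "_" or "(" gets "_" appended; the following character is kept via \1.
PATTERN = re.compile(r"\.key([^_(])")


def rename_key_attr(line: str) -> str:
    # sed's s/// without the g flag replaces only the first match per line.
    return PATTERN.sub(r".key_\1", line, count=1)


samples = [
    "stream.key.value",          # attribute access -> renamed
    "stream.key()",              # method call -> untouched
    "stream.key_.value",         # already renamed -> untouched
    "k = rngs.params.key[...]",  # indexing -> renamed
]
for s in samples:
    print(rename_key_attr(s))
```

A `.key` at the very end of a line is not matched, because the pattern requires a following character. Identifiers that merely start with `key` do match: `d.keys()` would become `d.key_s()`. The resulting diff is therefore worth a quick scan.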
Thanks @samanklesaria! I'll need to run all internal tests before merging.
As a follow-up PR, we could consider:
- replacing all uses of `jax.random.<function>(rngs(), ...)` with `rngs.<function>(...)`. I've done some of this.
- when we do need a key, favoring `rngs.key()` or `rngs.some_stream.key()`.
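The suggested follow-up style could look roughly like the sketch below. `SketchRngs` and `SketchStream` are hypothetical stand-ins, not Flax's `Rngs`/`RngStream`. The convenience samplers `normal` and `uniform` are assumptions chosen to show the pattern. Each one draws a fresh key from the default stream and forwards it to the matching `jax.random` function, so call sites no longer pass `rngs()` explicitly.

```python
import jax
import jax.numpy as jnp


class SketchStream:
    """Hypothetical stream: a base key plus a counter (not Flax's RngStream)."""

    def __init__(self, seed: int):
        self.key_ = jax.random.key(seed)
        self.count = 0

    def key(self) -> jax.Array:
        # Fresh key derived from the base key and the running counter.
        new_key = jax.random.fold_in(self.key_, self.count)
        self.count += 1
        return new_key


class SketchRngs:
    """Hypothetical container of named streams; `default` backs rngs() and rngs.key()."""

    def __init__(self, default: int, **streams: int):
        self.default = SketchStream(default)
        for name, seed in streams.items():
            setattr(self, name, SketchStream(seed))

    def __call__(self) -> jax.Array:
        return self.default.key()

    def key(self) -> jax.Array:
        return self.default.key()

    # Convenience samplers: draw a key, then forward to the jax.random function.
    def normal(self, shape, dtype=jnp.float32):
        return jax.random.normal(self.key(), shape, dtype)

    def uniform(self, shape, minval=0.0, maxval=1.0):
        return jax.random.uniform(self.key(), shape, minval=minval, maxval=maxval)


rngs = SketchRngs(0, dropout=1)
x_old = jax.random.normal(rngs(), (2, 3))  # style to be replaced
x_new = rngs.normal((2, 3))                # proposed style
k = rngs.dropout.key()                     # explicit key from a named stream
```

Keeping the samplers as thin wrappers means the behavior of each `jax.random` function is unchanged. Only the key plumbing moves into the container.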