jraph
Switch from `jax.tree_multimap` to `jax.tree_map`
Jraph still uses `jax.tree_multimap`, which emits a deprecation warning. This can be a problem for downstream users. For instance, our CI in Flax fails whenever a deprecation warning is raised. I have added an exception for Jraph for now, but since `tree_multimap` can be replaced in place by `tree_map`, this should be an easy fix!
See https://github.com/google/flax/issues/2037
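As a rough illustration of the in-place replacement, here is a minimal sketch. The `params`/`grads` dictionaries and the `sgd_step` helper are hypothetical and not taken from Jraph's code. The point is that `tree_map` already accepts one or more trees with matching structure, so a `tree_multimap(f, a, b)` call becomes `tree_map(f, a, b)` with the same arguments. The sketch calls `jax.tree_util.tree_map` rather than the top-level `jax.tree_map` alias, since more recent JAX releases also deprecate that alias.

```python
import jax
import jax.numpy as jnp

# Hypothetical example trees (not from Jraph): two pytrees with the same structure.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.05)}


def sgd_step(params, grads, lr=0.1):
    # Before (deprecated):
    #   jax.tree_multimap(lambda p, g: p - lr * g, params, grads)
    # After: tree_map applies the function leaf by leaf across all trees passed in.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


new_params = sgd_step(params, grads)
```

The call signature is unchanged apart from the function name, which is why the swap can be done mechanically.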