CommonLoopUtils
CLU lets you write beautiful training loops in JAX.
Implement recent changes to the DatasetIterator interface.
Implement __repr__() for several classes.
Add DataIterator and checkpoint handler for Grain.
Adjust imports to avoid accidentally pulling Keras into CLU. https://github.com/google/CommonLoopUtils/pull/336
Change deprecated `jax.tree_map` to avoid warnings: `DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).`
Support sharding for get_parameter_overview
I am moving some of my models to Keras 3 and using the JAX backend for training. I am on a slightly older TensorBoard (2.12), so when it pulls in keras callback...
Disable typeguard for TF
Fix mock_modules for `import x.y.z as a`
Fix type annotations to pass pytype checks.
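The `jax.tree_map` deprecation mentioned above can be handled in user code roughly as follows. This is a minimal sketch, not the change CLU itself made: the `params` tree and the `tree_map` alias are hypothetical names introduced for illustration, and the only JAX calls used are the two replacements named in the warning text.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree, used only for illustration.
params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

# Deprecated spelling that triggers the DeprecationWarning:
#   doubled = jax.tree_map(lambda x: x * 2, params)

# Prefer jax.tree.map where the jax.tree module exists (JAX v0.4.25 or newer);
# otherwise fall back to jax.tree_util.tree_map, which works on any JAX version.
tree_map = jax.tree.map if hasattr(jax, "tree") else jax.tree_util.tree_map

# Apply the function to every leaf array while keeping the dict structure.
doubled = tree_map(lambda x: x * 2, params)
```

Code that does not need to support older JAX releases can call `jax.tree.map` directly and drop the alias.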