CommonLoopUtils
CLU lets you write beautiful training loops in JAX.
Update the return type of clu.Metrics.compute_value() to avoid a signature-mismatch error, following go/pax/metrics#compute-value-return-types.
This PR makes `Metric` and `Collection` inherit from `flax.struct.PyTreeNode`. This makes creating new metric types more ergonomic, because the class no longer needs to be decorated with `flax.struct.dataclass`.
Warm up the input pipeline in the background.
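Background warmup can be pictured as a thread that starts pulling the first batches while the main thread is still busy with setup work such as model initialization or compilation. The sketch below is plain Python. `warmup_in_background` and its `num_warmup` parameter are hypothetical names rather than CLU's API, and CLU's actual implementation may differ.

```python
import threading
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def warmup_in_background(it: Iterator[T], num_warmup: int = 1) -> Iterator[T]:
  """Hypothetical helper: pulls the first `num_warmup` elements of `it` on a
  background thread so they are ready when the training loop asks for them."""
  warmed: List[T] = []
  errors: List[BaseException] = []

  def _worker() -> None:
    try:
      for _ in range(num_warmup):
        warmed.append(next(it))
    except StopIteration:
      pass  # The iterator was shorter than the warmup window.
    except Exception as e:  # Re-raised on the consuming thread later.
      errors.append(e)

  # The thread starts right away, before the caller consumes anything.
  thread = threading.Thread(target=_worker, daemon=True)
  thread.start()

  def _resume() -> Iterator[T]:
    thread.join()  # Usually already finished by the time training starts.
    yield from warmed
    if errors:
      raise errors[0]
    yield from it  # Continue with the remaining elements, in order.

  return _resume()
```

Calling the helper early (for example, before compiling the train step) lets pipeline startup overlap with that setup, while the consumer still sees every element exactly once and in the original order.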
Add a point cloud summary writer to the TensorBoard interface and the metric writer.
Make g3pdb lazy again
Make sklearn import lazy
Rename differentially_private_aggregate -> optax.contrib
Adds `reshuffle_each_iteration` argument to `deterministic_data.create_dataset()`. This argument is passed to `tf.data.Dataset.shuffle()` and controls whether the dataset is reshuffled each time it is iterated over. The default value is `None`, which...
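The effect of `reshuffle_each_iteration` can be illustrated without TensorFlow. The sketch below uses a hypothetical `ShuffledDataset` class that performs a full in-memory shuffle, unlike `tf.data.Dataset.shuffle()`, which shuffles within a fixed-size buffer. With reshuffling enabled, every pass over the data produces a new order. With it disabled, every pass repeats the first order.

```python
import random
from typing import Generic, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


class ShuffledDataset(Generic[T]):
  """Hypothetical stand-in showing what `reshuffle_each_iteration` controls."""

  def __init__(self, items: Iterable[T], seed: int, reshuffle_each_iteration: bool):
    self._items: List[T] = list(items)
    self._seed = seed
    self._reshuffle = reshuffle_each_iteration
    self._num_passes = 0

  def __iter__(self) -> Iterator[T]:
    # With reshuffling, each pass derives its own seed. Otherwise every pass
    # reuses the seed of pass 0 and therefore yields the same permutation.
    pass_index = self._num_passes if self._reshuffle else 0
    self._num_passes += 1
    order = list(self._items)
    random.Random(self._seed * 1_000_003 + pass_index).shuffle(order)
    return iter(order)
```

Iterating a `reshuffle_each_iteration=False` instance twice gives identical lists, which can be convenient for evaluation. Training usually prefers a fresh order per epoch.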
Silence some pytype errors related to a JAX build refactor. This build change allows pytype to propagate annotations that it previously did not, and because of this it starts flagging...
Update references to JAX's GitHub repo. JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax.