CommonLoopUtils

CLU lets you write beautiful training loops in JAX.

42 CommonLoopUtils issues

Update the `clu.Metrics.compute_value()` return type to avoid a signature-mismatch error, as described in go/pax/metrics#compute-value-return-types.

This PR makes `Metric` and `Collection` inherit from `flax.struct.PyTreeNode`. This makes creating new metric types more ergonomic, because subclasses no longer need to be decorated with `flax.struct.dataclass`.

Add a point cloud summary writer to the TensorBoard interface and the metric writer.

Rename `differentially_private_aggregate` -> `optax.contrib`

Adds a `reshuffle_each_iteration` argument to `deterministic_data.create_dataset()`. This argument is passed to `tf.data.Dataset.shuffle()` and controls whether the dataset is reshuffled each time it is iterated over. The default value is `None`, which...
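The sketch below illustrates only the `tf.data.Dataset.shuffle()` behavior that the argument is forwarded to; it does not call CLU's `create_dataset()`, whose full signature is not shown here. The helper `two_epochs` is hypothetical. With a fixed `seed` and `reshuffle_each_iteration=False`, every pass over the dataset yields the same permutation; with `True`, each new iterator draws a fresh permutation.

```python
import tensorflow as tf


def two_epochs(reshuffle_each_iteration):
    """Hypothetical helper: iterate one shuffled dataset twice, return both orders."""
    ds = tf.data.Dataset.range(10).shuffle(
        buffer_size=10,
        seed=42,
        reshuffle_each_iteration=reshuffle_each_iteration,
    )
    # Each call to as_numpy_iterator() starts a new pass (epoch) over ds.
    first = [int(x) for x in ds.as_numpy_iterator()]
    second = [int(x) for x in ds.as_numpy_iterator()]
    return first, second


# Reshuffling disabled: both passes see the identical permutation.
fixed_1, fixed_2 = two_epochs(reshuffle_each_iteration=False)

# Reshuffling enabled: each pass draws a new permutation of the same elements.
fresh_1, fresh_2 = two_epochs(reshuffle_each_iteration=True)
```

Disabling reshuffling trades per-epoch variety for an order that is identical on every pass, which can help when comparing evaluation runs.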

Silence some pytype errors related to a JAX build refactor. This build change allows pytype to propagate annotations that it previously did not, and because of this it starts flagging...

Update references to JAX's GitHub repo. JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax.