CommonLoopUtils
CLU lets you write beautiful training loops in JAX.
[bit_dataset] Allow passing the image interpolation method to the image-resizing preprocessing functions.
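A minimal sketch of what an interpolation argument on a resizing preprocessing function can look like. It is written in JAX and uses `jax.image.resize` as a stand-in. `resize_image` and its parameters are hypothetical and are not the `bit_dataset` API, which may rely on a different image library.

```python
import jax
import jax.numpy as jnp


def resize_image(image: jnp.ndarray, height: int, width: int,
                 method: str = "bilinear") -> jnp.ndarray:
    """Hypothetical preprocessing op: resize an HWC image.

    `method` is forwarded to jax.image.resize, which accepts names such as
    "nearest", "bilinear", "bicubic", "lanczos3" and "lanczos5".
    """
    channels = image.shape[-1]
    return jax.image.resize(image, (height, width, channels), method=method)


# Same input, two different interpolation choices picked by the caller.
image = jnp.full((32, 32, 3), 0.5)
nearest = resize_image(image, 16, 16, method="nearest")
bicubic = resize_image(image, 64, 64, method="bicubic")
```

Keeping the method as a plain string argument lets a dataset config choose, for example, nearest-neighbour resizing for segmentation masks and a smoother kernel for photos.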
Starts using `tree_map` from `jax.tree_util` to avoid a deprecation warning.
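For reference, `jax.tree_util.tree_map` applies a function to every leaf of a pytree (nested dicts, lists, tuples and registered classes) and returns a tree with the same structure. The sketch assumes the deprecated spelling was the top-level `jax.tree_map` alias; the `params` and `grads` values are made up for illustration.

```python
import jax

params = {"dense": {"w": 1.0, "b": 0.5}, "scale": [2.0, 3.0]}
grads = {"dense": {"w": 0.1, "b": -0.2}, "scale": [1.0, 0.0]}

# Use jax.tree_util.tree_map instead of the deprecated jax.tree_map alias.
# The function is applied leaf by leaf; the dict/list nesting is preserved.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# With several trees of identical structure, the function receives the
# matching leaves from each tree (here: a plain SGD step with lr = 0.5).
updated = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)
```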
`seed` defaults to `None` in `as_dataset()`.
Makes `Metric`/`Collection` inherit from `flax.struct.PyTreeNode`. This removes the need to decorate `Metric`/`Collection` subclasses with `@flax.struct.dataclass` and eliminates one class of confusing errors. Note that all calling code has to be updated...
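The sketch below illustrates, in plain JAX, why inheriting from a pytree base class makes the decorator unnecessary: the base class's `__init_subclass__` hook turns each subclass into a frozen dataclass and registers it with JAX automatically. `PyTreeNodeSketch` and the `Average` metric are hypothetical stand-ins written for illustration; they are not Flax's `PyTreeNode` source or CLU's `Metric` API.

```python
import dataclasses

import jax
import jax.numpy as jnp


class PyTreeNodeSketch:
    """Hypothetical stand-in for a PyTreeNode-style base class."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass becomes a frozen dataclass without a decorator...
        dataclasses.dataclass(frozen=True)(cls)
        field_names = [f.name for f in dataclasses.fields(cls)]

        def flatten(obj):
            # All fields are children (leaves); there is no static aux data.
            return [getattr(obj, name) for name in field_names], None

        def unflatten(_aux, children):
            return cls(*children)

        # ...and is registered as a pytree, so jit/tree_map can traverse it.
        jax.tree_util.register_pytree_node(cls, flatten, unflatten)


class Average(PyTreeNodeSketch):
    """Hypothetical running-average metric: a sum of values and a count."""
    total: jnp.ndarray
    count: jnp.ndarray

    @classmethod
    def from_values(cls, values):
        return cls(total=jnp.sum(values),
                   count=jnp.asarray(values.size, jnp.float32))

    def merge(self, other):
        return type(self)(total=self.total + other.total,
                          count=self.count + other.count)

    def compute(self):
        return self.total / self.count


@jax.jit
def step(metric, batch):
    # The metric can cross the jit boundary because it is a registered pytree.
    return metric.merge(Average.from_values(batch))


m = Average.from_values(jnp.array([1.0, 2.0]))
m = step(m, jnp.array([3.0, 6.0]))
```

Here `m.compute()` averages all four values seen so far. A subclass that forgot a decorator under the old scheme would fail to pass through `jax.jit`; with the hook in the base class that mistake cannot happen.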
Make `MultimodalEncDecFeatureConverterFactory` a dataclass.
Add WebLI tasks that use Grain, along with a `GetDatasetFn` for them.
Always pass checkpoint argument to TfDatasetIterator.
Fix silent failure of dataset checkpointing/restoration introduced in cl/472724335.
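The entries above do not describe how `TfDatasetIterator` stores its state or what the silent failure was. As a hedged, library-agnostic illustration only, the hypothetical `CheckpointableIterator` below makes `checkpoint` a required keyword-only argument, so every caller must choose explicitly, and raises an error instead of returning an empty state when asked to save or restore with checkpointing disabled. The class and its `get_state`/`set_state` method names are assumptions, not CLU's interface.

```python
from typing import Any, Iterator, Sequence


class CheckpointableIterator:
    """Hypothetical iterator over in-memory elements with savable position."""

    def __init__(self, elements: Sequence[Any], *, checkpoint: bool):
        # Keyword-only and without a default: omitting it is a TypeError.
        self._elements = elements
        self._checkpoint = checkpoint
        self._position = 0

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        if self._position >= len(self._elements):
            raise StopIteration
        element = self._elements[self._position]
        self._position += 1
        return element

    def get_state(self) -> dict:
        # Fail loudly rather than returning a state that would later
        # "restore" the iterator to the beginning without any error.
        if not self._checkpoint:
            raise RuntimeError("Iterator was created with checkpoint=False.")
        return {"position": self._position}

    def set_state(self, state: dict) -> None:
        if not self._checkpoint:
            raise RuntimeError("Iterator was created with checkpoint=False.")
        self._position = state["position"]


# Save after one element, restore into a fresh iterator, and resume.
it = CheckpointableIterator(["a", "b", "c"], checkpoint=True)
first = next(it)
state = it.get_state()
resumed = CheckpointableIterator(["a", "b", "c"], checkpoint=True)
resumed.set_state(state)
second = next(resumed)
```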
Add usage logging.