flax
Flax is a neural network library for JAX that is designed for flexibility.
Update minimal JAX version to latest (0.3.16)
Equations rendered incorrectly in Guides # What does this PR do? Fixes math expressions rendered incorrectly in the "JAX for the Impatient" and "Flax Basics" guides.
# What does this PR do? Adds a `path_value_map` function to `traverse_util`. This function makes it easy to create functions that take in Flax's `variables` structures and output path...
RNN FLIP
# RNN Flip - Start Date: 2022-08-18 - FLIP PR: N/A - FLIP Issue: [#2396](https://github.com/google/flax/issues/2396) - Authors: Jasmijn Bastings (@bastings) and Cristian Garcia (@cgarciae) ## Summary This FLIP proposes a...
Plumb spmd_axis_name from vmap_with_axes through to JAX vmap
# What does this PR do? Adds a Transfer Learning guide which includes: * Loading a pre-trained model * Doing parameter surgery * Freezing layers and implementing Differential Learning Rates...
Internal
Internal
# What does this PR do? Overhauls `tabulate` with a new system to capture call information and fixes #2274 and #2359. ### Changes * Adds a new thread-local `_CallInfoContext` context...
# What does this PR do? Testing... DO NOT MERGE
Add optional "strict" kwarg to restore_checkpoint to require the existence of checkpoint files. Currently we return the target / template object unmodified if no checkpoint directory or file exists. This...
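The difference between the current lenient behavior and the proposed strict behavior can be sketched with a small stand-in. This is a minimal illustration only: `restore_checkpoint_sketch` is a hypothetical helper written with the standard library, not Flax's `restore_checkpoint`, and the way it reads the file is an assumption made to keep the example self-contained.

```python
import os
import tempfile


def restore_checkpoint_sketch(ckpt_path, target, strict=False):
    """Hypothetical stand-in for illustration; not Flax's implementation."""
    if not os.path.exists(ckpt_path):
        if strict:
            # Strict mode: a missing checkpoint is an error.
            raise FileNotFoundError(f"No checkpoint found at {ckpt_path!r}")
        # Lenient mode: return the target / template object unmodified.
        return target
    # Assumption: a real implementation would deserialize into `target`;
    # here we simply return the raw bytes.
    with open(ckpt_path, "rb") as f:
        return f.read()


with tempfile.TemporaryDirectory() as tmp:
    missing = os.path.join(tmp, "checkpoint_0")  # never created
    template = {"params": None}

    # Lenient call silently hands back the template.
    lenient_result = restore_checkpoint_sketch(missing, template)

    # Strict call surfaces the missing file instead.
    try:
        restore_checkpoint_sketch(missing, template, strict=True)
        strict_raised = False
    except FileNotFoundError:
        strict_raised = True
```

In the lenient case a caller cannot tell a fresh template from a restored state, which is the failure mode a strict option would make visible.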