optax
Optax is a gradient processing and optimization library for JAX.
Reference: #594. This PR aims to add a new `GradientTransformation` that allows scaling by the gradient norm. Request for review: @fabianp
#965
**Feature request:** Add the [**ACProp**](https://juntang-zhuang.github.io/acprop/index.html) optimizer. I've created a PR for this: #966.
Move clipping transforms to `optax.transforms`.
Good morning! The `masked` function computes `mask_tree` as follows: `mask_tree = mask(params) if callable(mask) else mask` (1), which is used twice (in `init_fn` and in `update_fn`). However, in some cases,...
RST formatting in mechanic
I added an epsilon value to the cosine similarity function to avoid the NaNs that occurred when you had a label vector [0, 0, 0] or when one...
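A minimal sketch of the idea in plain `jax.numpy`, assuming the epsilon acts as a floor on each vector's norm. `cosine_similarity` here is a hypothetical standalone function, not optax's loss implementation. Flooring the squared norm *before* the square root keeps both the value and the gradient finite at an all-zero input, whereas the unguarded version computes 0/0.

```python
import jax
import jax.numpy as jnp


def cosine_similarity(a, b, epsilon=1e-8):
    """Cosine similarity along the last axis, safe for all-zero vectors."""

    def safe_norm(x):
        # max(||x||^2, eps^2) keeps sqrt away from 0, so neither the norm
        # nor its gradient becomes NaN when x is all zeros.
        sq = jnp.sum(x * x, axis=-1, keepdims=True)
        return jnp.sqrt(jnp.maximum(sq, epsilon**2))

    return jnp.sum((a / safe_norm(a)) * (b / safe_norm(b)), axis=-1)


zeros = jnp.zeros(3)
target = jnp.array([1.0, 0.0, 0.0])
print(cosine_similarity(zeros, target))   # 0.0 instead of NaN
print(cosine_similarity(target, target))  # approximately 1.0
grad = jax.grad(lambda a: cosine_similarity(a, target))(zeros)  # finite
```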
### Motivation
Recently, for a personal project (#641), I wanted to create a custom `GradientTransformation` object and found that there wasn't a tutorial or walkthrough on how to create...
Remove useless inner jit
**Feature request:** Add a GPU/TPU-friendly solver for the [assignment problem](https://en.wikipedia.org/wiki/Assignment_problem). For context, see:
1. [scipy.optimize.linear_sum_assignment](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html)
2. https://github.com/google/jax/issues/10403
3. https://github.com/google/jax/pull/16974

The last page contains the following comment:
> There is a...
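
To pin down what such a solver must return, here is a brute-force reference in JAX that enumerates every permutation of a square cost matrix and picks the cheapest one. `brute_force_assignment` is a hypothetical helper for checking a real solver on tiny inputs. It is not the GPU/TPU-friendly algorithm being requested, because its cost grows factorially with n.

```python
import itertools

import jax.numpy as jnp


def brute_force_assignment(cost):
    """Exact minimum-cost assignment for a small square cost matrix.

    Returns (row_indices, col_indices), mirroring the shape of the result of
    scipy.optimize.linear_sum_assignment. Runtime is O(n! * n).
    """
    n = cost.shape[0]
    perms = jnp.array(list(itertools.permutations(range(n))))  # shape (n!, n)
    rows = jnp.arange(n)
    # cost[rows, perms][k, i] == cost[i, perms[k, i]]
    totals = cost[rows, perms].sum(axis=1)
    best = jnp.argmin(totals)
    return rows, perms[best]


cost = jnp.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]])
row, col = brute_force_assignment(cost)
print(col, cost[row, col].sum())  # [1 0 2] 5
```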