
Optax is a gradient processing and optimization library for JAX.

153 optax issues

Reference: #594. This PR adds a new `GradientTransformation` that scales updates by the gradient norm. Review requested from @fabianp.

**Feature request:** Add the [**ACProp**](https://juntang-zhuang.github.io/acprop/index.html) optimizer. I've created a PR for this: #966.

Good morning! The `masked` function computes `mask_tree` as follows: `mask_tree = mask(params) if callable(mask) else mask` (1), which is used twice (in `init_fn` and `update_fn`). However, in some cases,...

I added an epsilon value to the cosine similarity function to avoid the NaNs that occurred when a label vector was [0, 0, 0] or when one...
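Below is a minimal JAX sketch of this kind of fix. It assumes the epsilon serves as a lower bound on each vector's norm. The function name `cosine_similarity_sketch` and the default `epsilon` are illustrative, not the optax signature. With the bound in place, an all-zero vector gives a similarity of 0 instead of 0/0 = NaN.

```python
import jax.numpy as jnp


def cosine_similarity_sketch(predictions, targets, epsilon=1e-8):
    """Cosine similarity along the last axis, with norms clipped from below."""
    pred_norm = jnp.linalg.norm(predictions, axis=-1, keepdims=True)
    targ_norm = jnp.linalg.norm(targets, axis=-1, keepdims=True)
    # Dividing by max(norm, epsilon) keeps zero vectors from producing 0 / 0.
    unit_pred = predictions / jnp.maximum(pred_norm, epsilon)
    unit_targ = targets / jnp.maximum(targ_norm, epsilon)
    return jnp.sum(unit_pred * unit_targ, axis=-1)
```

This only guards the forward value. Differentiating `jnp.linalg.norm` at an exact zero vector can still produce NaN gradients, so a complete fix may need to handle that case as well.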

**Motivation:** Recently, for a personal project (#641), I wanted to create a custom `GradientTransformation` object and found that there wasn't a tutorial or walkthrough on how to create...

Label: documentation

**Feature request:** Add a GPU/TPU-friendly solver for the [assignment problem](https://en.wikipedia.org/wiki/Assignment_problem). For context, see:

1. [scipy.optimize.linear_sum_assignment](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html)
2. https://github.com/google/jax/issues/10403
3. https://github.com/google/jax/pull/16974

The last page contains the following comment:

> There is a...

Label: enhancement
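To show what a vectorized solver could look like, here is a sketch of Bertsekas's auction algorithm (Jacobi variant) in JAX. This is a swapped-in technique chosen for illustration, not the solver proposed in the issue. The name `auction_assignment` is hypothetical. The sketch assumes a square float benefit matrix with n >= 2 and maximizes total benefit, so a `linear_sum_assignment`-style cost matrix has to be negated first. For integer benefits, `eps < 1/n` guarantees an optimal assignment. Otherwise the result is only within `n * eps` of optimal.

```python
import jax
import jax.numpy as jnp


@jax.jit
def auction_assignment(benefit, eps=0.1):
    """Jacobi auction on a square benefit matrix (maximization).

    Returns person_to_object: person i is assigned object person_to_object[i].
    """
    n = benefit.shape[0]
    objects = jnp.arange(n, dtype=jnp.int32)

    def still_bidding(state):
        person_to_object, _, _ = state
        return jnp.any(person_to_object < 0)

    def bidding_round(state):
        person_to_object, object_to_person, prices = state
        # Net value of each object to each person at the current prices.
        values = benefit - prices[None, :]
        top_vals, top_idx = jax.lax.top_k(values, 2)
        best_obj = top_idx[:, 0]
        # Bid = current price + (best value - second-best value) + eps.
        bids = prices[best_obj] + (top_vals[:, 0] - top_vals[:, 1]) + eps
        # Only unassigned persons bid, each on its single best object.
        bidding = (person_to_object < 0)[:, None] & (best_obj[:, None] == objects[None, :])
        bid_matrix = jnp.where(bidding, bids[:, None], -jnp.inf)
        best_bid = bid_matrix.max(axis=0)  # highest bid per object
        winner = bid_matrix.argmax(axis=0).astype(jnp.int32)
        got_bid = best_bid > -jnp.inf
        # Previous owners of outbid objects become unassigned; index n is out
        # of bounds and is skipped by mode="drop".
        evicted = jnp.where(got_bid & (object_to_person >= 0), object_to_person, n)
        person_to_object = person_to_object.at[evicted].set(-1, mode="drop")
        # Winners take the objects they bid on, at the new price.
        winners = jnp.where(got_bid, winner, n)
        person_to_object = person_to_object.at[winners].set(objects, mode="drop")
        object_to_person = jnp.where(got_bid, winner, object_to_person)
        prices = jnp.where(got_bid, best_bid, prices)
        return person_to_object, object_to_person, prices

    init = (
        -jnp.ones(n, dtype=jnp.int32),
        -jnp.ones(n, dtype=jnp.int32),
        jnp.zeros(n, dtype=benefit.dtype),
    )
    person_to_object, _, _ = jax.lax.while_loop(still_bidding, bidding_round, init)
    return person_to_object
```

Each round is a fixed set of dense array operations, which makes it a good fit for accelerators. The trade-off is that the number of rounds depends on the data and on `eps`.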