flax
Intermediate value capture API via JAX's `hijax.Box`
JAX has a new `hijax.Box` mechanism that can insert arbitrary values during the forward and backward passes. It could be a good alternative to the current `sow` and `perturb` APIs at the Flax level.
The feature is not complete yet (e.g., it does not yet work with `vmap`/`scan`, or in some other corner cases), so I'm going to do some prototyping first.