
Intermediate value capture API via JAX's `hijax.Box`

Open IvyZX opened this issue 5 months ago • 0 comments

JAX has a new `hijax.Box` mechanism that can insert arbitrary values during the forward and backward passes. It could be a good alternative to the current Flax-level `sow` and `perturb` APIs.

The feature is not complete yet (e.g., it does not yet work with `vmap`/`scan`, and fails in some other corner cases). I'm going to do some prototyping first.

IvyZX, Sep 04 '25 02:09