Ivy Zheng
Add a `perturb()` method to `nn.Module` that can be used to capture intermediate gradients inside a module's `__call__`. See the example in the docstring. There will be a future Colab/notebook...
This allows adding and deleting entries from the `nnx.State` mappings, which is crucial for flexible model surgery.
Goal: force users to pass a `transform_metadata` argument when they apply a transform to annotated variables. TODO: Can we auto-infer the transform axis name from the annotated variables and only throw...
* Moved all the Linen logical axis deduction logic from `linen/spmd.py` to `core/spmd.py`, to be shared with NNX APIs.
JAX has a new `hijax.Box` mechanism that can insert arbitrary values during the forward and backward passes. This could be a good alternative to the current `sow` and `perturb` APIs on...
Master tracker: #4924. The whole feature is still WIP; feedback welcome. `hijax.Box` does not yet work with `vmap` or `scan`, nor in some edge cases such as double `jit`.
WIP: waiting for both `tensorflow` and `tensorflow_text` to release 2.20.0.
Python 3.13 has been out for a while, and `tensorflow` has supported it since `2.20.0`, released months ago. However, the `tensorflow_text` PyPI package is still at `2.19.0`. Can we have a...