flax
Add perturb() to allow capturing intermediate gradients
Add a perturb() method to nn.Module that can be used to capture intermediate gradients inside a module's __call__.
See the example in the docstring. A Colab notebook with more thorough examples will follow.
Looks good! Should we add a simple test?
Test added!
@IvyZX @levskaya we forgot to add this new method to the Sphinx page in docs/api_reference/flax.linen.rst; it's currently not being rendered.