Allow gradient transform parameters to be dynamic
@hawkinsp Pinging you since you recently fixed some type annotation errors. The optimizer classes accepting only `float` breaks type annotations for the Tjax shim classes (https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/gradient/transforms.py). Tjax provides a parallel set of optimizers, identical in functionality except that they support dynamic optimizer parameters. They do this by storing dynamic fields in a dataclass rather than closing over parameters.
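To illustrate the distinction, here's a minimal sketch (the `scale`/`Scale` names and signatures are illustrative, not the actual Optax or Tjax APIs):

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

# Closure-based design (as in Optax): the hyperparameter is captured at
# construction time, and its annotation pins it to a static float.
def scale(step_size: float):
    def update(updates):
        return jax.tree_util.tree_map(lambda u: step_size * u, updates)
    return update

# Dataclass-based design (as in Tjax): the hyperparameter is a field, so it
# can be a plain float or a traced jax.Array, and it remains inspectable
# after construction.
@dataclass
class Scale:
    step_size: jax.Array | float

    def update(self, updates):
        return jax.tree_util.tree_map(lambda u: self.step_size * u, updates)

grads = {'w': jnp.ones(3)}
print(Scale(step_size=jnp.asarray(0.1)).update(grads))
```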
However, the optimizer functionality is delegated to Optax, which means calling Optax update methods with JAX arrays. Is there any reason Optax methods can't accept such arrays? Would it be possible to widen these parameter types to `jax.Array | float`?
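For concreteness, this is the kind of call that works at runtime but fails static type checking (assuming, say, `adam`'s `b1` is still annotated as plain `float`):

```python
import jax.numpy as jnp
import optax

# Optax only does arithmetic with b1, so a traced array works at runtime,
# but pyright rejects the call because b1 is annotated as float.
b1 = jnp.asarray(0.9)
tx = optax.adam(learning_rate=1e-3, b1=b1)  # reported as a type error
```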
@mtthss Would you mind taking a look at this?
Hello. I was on paternity leave for most of the past year. Are you still having this issue? Happy to look into it if that's the case.
@mtthss Hello, yes I'm still getting the type errors. (Congrats on becoming a father!)
Which arguments are causing errors for you?
All of the ones I changed. I maintain a shim library so that I can use Optax with dynamic, inspectable parameters. What I ended up doing for the time being is to mark every use of Optax with `# pyright: ignore`.
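That is, something like this at every call site (a rough sketch of the pattern, not the actual Tjax code):

```python
import jax.numpy as jnp
import optax

# A dynamic hyperparameter stored in a dataclass field reaches Optax as a
# jax.Array, so each call needs a suppression to satisfy pyright.
decay = jnp.asarray(0.9)
tx = optax.ema(decay=decay)  # pyright: ignore
```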
Thanks for taking a look at this.
(Of course, my dream would be that you adopt the dynamic design so that I don't have to maintain my shim library 😄.)