
Allow gradient transform parameters to be dynamic

Open NeilGirdhar opened this issue 2 years ago • 13 comments

NeilGirdhar, Mar 22 '23

@hawkinsp Pinging you since you recently repaired some type annotation errors. Because Optax's optimizers accept only `float` parameters, the type annotations of the Tjax shim classes (https://github.com/NeilGirdhar/tjax/blob/main/tjax/_src/gradient/transforms.py) fail to type-check. Tjax provides a parallel set of optimizers that are identical in functionality, except that they support dynamic optimizer parameters. They do this by storing the dynamic fields in a dataclass rather than closing over the parameters.
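To make the two designs concrete, here is a minimal pure-JAX sketch. `closure_sgd` mimics the closure style, where the learning rate is captured as a Python value. `DataclassSGD` is a hypothetical stand-in for a Tjax-style transform that stores the learning rate as a pytree leaf. Neither is the actual Optax or Tjax code: the names and the plain-SGD update rule are assumptions made for illustration.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


def closure_sgd(learning_rate: float) -> Callable[[Any], Any]:
    """Closure style: the learning rate is captured by the inner function,
    so it is not part of any pytree and cannot be read back or replaced."""
    def update(grads):
        return jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
    return update


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DataclassSGD:
    """Hypothetical dataclass style: the learning rate is a pytree leaf, so
    it can be a traced jax.Array, inspected, or differentiated."""
    learning_rate: jax.Array

    def update(self, grads):
        return jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)

    def tree_flatten(self):
        # The learning rate is the only child; there is no static aux data.
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def step(transform: DataclassSGD, params, grads):
    # transform.learning_rate arrives here as a traced array, so calling
    # step with a different value reuses the compiled function.
    updates = transform.update(grads)
    return jax.tree_util.tree_map(jnp.add, params, updates)


params = {"w": jnp.array([1.0, 2.0])}
grads = {"w": jnp.array([0.5, 0.5])}
new_params = step(DataclassSGD(jnp.asarray(0.1)), params, grads)
```

Because `learning_rate` is a leaf, `jax.grad` can also differentiate through it, e.g. `jax.grad(lambda lr: step(DataclassSGD(lr), params, grads)["w"].sum())(jnp.asarray(0.1))`.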

However, the optimizer functionality itself is delegated to Optax, which means calling Optax's update functions with JAX arrays as parameters. Is there any reason Optax methods can't accept such arrays? Would it be possible to widen these parameter types to `jax.Array | float`?

NeilGirdhar, Mar 30 '23

@mtthss Would you mind taking a look at this?

NeilGirdhar, May 11 '23

Hello. I was on paternity leave for most of the past year. Are you still having this issue? Happy to look into it if that's the case.

mtthss, Oct 10 '23

@mtthss Hello, yes I'm still getting the type errors. (Congrats on becoming a father!)

NeilGirdhar, Oct 10 '23

Which arguments are causing errors for you?

mtthss, Oct 10 '23

> Which arguments are causing errors for you?

All of the ones I changed. I maintain a shim library so that I can use Optax with dynamic, inspectable parameters. For the time being, I've marked every use of Optax with `# pyright: ignore`.

Thanks for taking a look at this.

NeilGirdhar, Oct 10 '23

(Of course, my dream would be that you adopt the dynamic design so that I don't have to maintain my shim library 😄.)

NeilGirdhar, Oct 10 '23