Type errors for general pytrees
Hello!
When it comes to annotations, optax currently relies heavily on `optax.Updates` and `optax.Params`, which are both aliases for `chex.ArrayTree`.
This makes sense, but for folks who run type checkers it means that many type errors occur when working with pytrees that aren't strictly the nested `Iterable` or `Mapping` types specified in chex. For example:
```python
from typing import Tuple

import flax.struct
import optax
from jax import numpy as jnp


@flax.struct.dataclass
class Params:
    weights: jnp.ndarray
    bias: jnp.ndarray


def make_optimizer(
    params: Params,
) -> Tuple[optax.GradientTransformation, optax.OptState]:
    """Make an optimizer."""
    optimizer = optax.sgd(learning_rate=1e-3)
    state = optimizer.init(params)  # Type error.
    return optimizer, state
```
A few questions from this:
- Is this considered a bug, or something that the optax team would be open to supporting? Are there better ways to suppress this error than simply adding a `# type: ignore`?
- It seems like type safety with optax could benefit immensely from support for generics, which have been available since Python 3.5 (`typing.Generic`, `typing.TypeVar`). Would optax be open to supporting this?
  - Simple example: with `ArrayTreeT = TypeVar("ArrayTreeT", bound=chex.ArrayTree)`, `optax.apply_updates()` could be annotated as `optax.apply_updates(params: ArrayTreeT, updates: ArrayTreeT) -> ArrayTreeT` to indicate that the argument and return types should all be the same.
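To make the `TypeVar` idea concrete, here is a minimal, stdlib-only sketch. It is not optax's API: `apply_updates` and `_tree_map2` here are hypothetical stand-ins, the tiny tree mapper only handles dicts, lists, plain tuples and dataclasses (in place of `jax.tree_util`), and the `TypeVar` is left unbounded rather than bound to `chex.ArrayTree`.

```python
import dataclasses
from typing import Any, Callable, TypeVar

# Hypothetical type variable. The proposal above would use
# TypeVar("ArrayTreeT", bound=chex.ArrayTree); it is unbounded here.
ArrayTreeT = TypeVar("ArrayTreeT")


def _tree_map2(fn: Callable[[Any, Any], Any], a: Any, b: Any) -> Any:
    """Illustrative two-tree map over dataclasses, dicts, lists and plain tuples."""
    if dataclasses.is_dataclass(a) and not isinstance(a, type):
        # Rebuild the same dataclass type with mapped field values.
        return dataclasses.replace(
            a,
            **{
                f.name: _tree_map2(fn, getattr(a, f.name), getattr(b, f.name))
                for f in dataclasses.fields(a)
            },
        )
    if isinstance(a, dict):
        return {k: _tree_map2(fn, a[k], b[k]) for k in a}
    if type(a) is list:
        return [_tree_map2(fn, x, y) for x, y in zip(a, b)]
    if type(a) is tuple:  # namedtuples are not handled in this sketch
        return tuple(_tree_map2(fn, x, y) for x, y in zip(a, b))
    return fn(a, b)  # leaf


def apply_updates(params: ArrayTreeT, updates: ArrayTreeT) -> ArrayTreeT:
    """Generic signature: the return type is tied to the argument types."""
    return _tree_map2(lambda p, u: p + u, params, updates)


@dataclasses.dataclass(frozen=True)
class Params:
    weights: float
    bias: float


new_params = apply_updates(Params(1.0, 2.0), Params(0.5, -0.5))
print(new_params)  # Params(weights=1.5, bias=1.5)
```

Because both arguments are `Params`, a type checker such as mypy infers `new_params` as `Params` rather than a generic tree type. One caveat: with `bound=chex.ArrayTree` as proposed, a type checker would still check `Params` against that bound, so a dataclass pytree would only be accepted if the bound itself covered it.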
Hi,
Thanks a lot for pointing this out! This is definitely something we should discuss, especially if it would be convenient for flax users to have these types supported.
I think we should stick to chex types as the standard, since that makes it easier to ensure safe interoperability with other JAX libraries that use chex (e.g. this bug isn't directly related to typing, but it shows the kind of bugs that can arise from differences in how libraries treat more complicated pytrees). If possible, we should avoid defining our own versions of common types in optax.
@hbq1: has there been a discussion in chex about extending `ArrayTree` to include some (common) dataclass implementations?