flax.linen.module.init still fails under dynamic type checking for nested modules

Open rainx0r opened this issue 1 year ago • 0 comments

Related issue: #3224

While the snippet posted in that issue does work now, there still seems to be a failure mode when nested modules (all of which are runtime type checked) are used.

Colab Link

from jax import numpy as jnp
import flax.linen as nn
import jax
from beartype import beartype

from jaxtyping import jaxtyped

@jaxtyped(typechecker=beartype)
class MyModuleInternal(nn.Module):
    hidden_size: int = 2

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.hidden_size)(x)


@jaxtyped(typechecker=beartype)
class MyModule(nn.Module):
    hidden_dim: int

    def setup(self) -> None:
        self.internal_module = MyModuleInternal(self.hidden_dim)  # <-- failure here
  
    def __call__(self, x):
        x = self.internal_module(x)
        return x


model = MyModule(5)

params = model.init(
    rngs={"params": jax.random.PRNGKey(0)},
    x=jnp.ones((1, 1)),
)

This snippet fails with the following error:

---------------------------------------------------------------------------
BeartypeCallHintParamViolation            Traceback (most recent call last)
    [... skipping hidden 1 frame]

<@beartype(__main__.check_params) at 0x7d4c408e2830> in check_params(__beartype_get_violation, __beartype_conf, __beartype_object_137766497224064, __beartype_object_99821132912832, __beartype_object_99821132891488, __beartype_object_137766477140992, __beartype_func, *args, **kwargs)

BeartypeCallHintParamViolation: Method __main__.check_params() parameter parent="MyModule(
    # attributes
    hidden_dim = 5
)" violates type hint typing.Union[typing.Type[flax.linen.module.Module], flax.core.scope.Scope, typing.Type[flax.linen.module._Sentinel], NoneType]

Looking at nn.Module's _ParentType, the parent argument is indeed annotated as Type[nn.Module], i.e. a class, whereas what is actually passed in is an instance of nn.Module. The same mismatch appears to have caused the previously reported failure in #3224: the PR that fixed it (#3371) changed the annotation from Type[Scope] to plain Scope, so that an instance is expected rather than a class.
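
To make the class-versus-instance distinction concrete, here is a minimal stdlib-only sketch. It does not use Flax or beartype: `Module` is a hypothetical stand-in for nn.Module, `matches` is a toy checker written for illustration, and the two hints are simplified versions of the union from the error message (the Scope and _Sentinel members are left out).

```python
from typing import Type, Union, get_args, get_origin


class Module:
    """Hypothetical stand-in for flax.linen.Module (not the real class)."""


def matches(value, hint) -> bool:
    """Toy runtime check covering only Union, Type[C], None and plain classes."""
    origin = get_origin(hint)
    if origin is Union:
        # A union is satisfied if any member is satisfied.
        return any(matches(value, arg) for arg in get_args(hint))
    if origin is type:
        # Type[C] expects a class object that is C or a subclass of C.
        (cls,) = get_args(hint)
        return isinstance(value, type) and issubclass(value, cls)
    if hint is type(None):
        return value is None
    # A plain class hint expects an instance of that class.
    return isinstance(value, hint)


# Simplified shape of the annotation reported in the error message.
ParentAsClass = Union[Type[Module], None]
# Simplified shape after a change analogous to Type[Scope] -> Scope.
ParentAsInstance = Union[Module, None]

parent = Module()  # a submodule receives its parent *instance*

instance_vs_class_hint = matches(parent, ParentAsClass)
class_vs_class_hint = matches(Module, ParentAsClass)
instance_vs_instance_hint = matches(parent, ParentAsInstance)
```

Under these assumptions, the instance is rejected by the Type[...] form and accepted by the plain form, which mirrors the violation raised for parent above.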

rainx0r, Mar 13 '24 15:03