
Improve Error Message: Jitting linen.Module

Status: Open · marcvanzee opened this issue 5 years ago · 13 comments

Example code:

import jax

# Model, params, test_ds and compute_metrics are defined elsewhere.
model = Model(1)

@jax.jit
def eval_step(model, params, batch):
  logits = model.apply({'params': params}, batch['X'])
  return compute_metrics(logits, batch['y'])

def eval_model(model, params, test_ds):
  metrics = eval_step(model, params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

print(eval_model(model, params, test_ds))

Throws the following general JAX error:

TypeError: Argument 'Model(
    # attributes
    features = 3
)' of type <class '__main__.Model'> is not a valid JAX type

Modules in Linen aren't pytrees, so they can't be flattened/unflattened as needed when entering and exiting JAX transformations. The common pattern is to pass the module to jit via static_argnums, which is equivalent to having the module instance behave like a constant inside the transformed function.

We should consider registering Modules as pytrees solely so that we can throw an error explaining this.
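A minimal sketch of the static_argnums pattern mentioned above, using plain JAX. TinyModel is a hypothetical stand-in for a Linen Module (a frozen, and therefore hashable, dataclass) so that the example runs without Flax; the assumption is that a real Module would be passed the same way.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a Linen Module: a frozen (hence hashable)
# dataclass holding only static configuration.
@dataclasses.dataclass(frozen=True)
class TinyModel:
    scale: float

    def apply(self, params, x):
        return self.scale * (x @ params["w"])


def eval_step(model, params, x):
    return model.apply(params, x)


# Argument 0 is static: `model` is hashed and treated as a Python constant
# during tracing instead of being converted to a JAX value.
eval_step_jit = jax.jit(eval_step, static_argnums=(0,))

model = TinyModel(scale=2.0)
params = {"w": jnp.ones((3, 1))}
x = jnp.ones((4, 3))
y = eval_step_jit(model, params, x)  # shape (4, 1)

# Without static_argnums, jit treats `model` as an array argument and
# rejects it with a TypeError.
try:
    jax.jit(eval_step)(model, params, x)
except TypeError as err:
    print("TypeError:", err)
```

Because the static argument is part of jit's cache key, calling eval_step_jit again with an equal TinyModel reuses the compiled function.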

marcvanzee, Jan 15 '21

@marcvanzee I could take this. To clarify before I start, is this what you had in mind:

@register_pytree_node_class
class Module:
  ...

and then define tree_flatten and tree_unflatten inside Module so that they raise an informative error? (This could be a TypeError, or I could define a new error type in errors.py.)
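A rough sketch of this base-class idea in plain JAX, assuming a hypothetical JitPytreeError subclass of TypeError (the name and location are placeholders, not Flax API). Since jit flattens its arguments through the pytree registry, the custom tree_flatten fires and its message is what the user sees. Note that this only covers objects whose type is exactly Module.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class JitPytreeError(TypeError):
    """Hypothetical error type; the real one could live in errors.py."""


@register_pytree_node_class
class Module:
    def tree_flatten(self):
        # JAX calls this whenever it flattens a Module, e.g. as a jit argument.
        raise JitPytreeError(
            "A Module cannot be passed as a regular jit argument; "
            "mark it static with static_argnums or static_argnames.")

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        raise JitPytreeError("Modules cannot be unflattened.")


@jax.jit
def step(model, x):
    return x + 1


caught = None
try:
    step(Module(), jnp.ones(2))
except JitPytreeError as err:
    caught = err
print(caught)
```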

melissatan, Mar 31 '22

@melissatan yes that is the idea 😄 .

marcvanzee, Apr 06 '22

The naive approach described above doesn't seem to work, at least not out of the box. I investigated using a minimal test case:

import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
      return x

x = jnp.ones((3, 2))
foo = Foo()
variables = foo.init(random.PRNGKey(0), x)

@jax.jit
def jitted_step_bad(model, params, x):
  unused_logits = model.apply(params, x)
  return params

_ = jitted_step_bad(foo, variables, x)

I observed that the "is not a valid JAX type" error is raised in jax/_src/api.py, from the _check_arg() helper used by _cpp_jit():

def _check_arg(arg):
  if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.") # <-- here

Here, _valid_jaxtype() looks for a match in a hashmap called pytype_aval_mappings, which lives in jax/interpreters/xla.py. However, that hashmap is different from the _registry mapping where register_pytree_node stores its entries, which lives in jax/_src/tree_util.py.

The pytype_aval_mappings hashmap in xla.py doesn't appear to be populated with the contents of tree_util.py's _registry; xla.py doesn't even import tree_util.py.

So either:

  • Could it be that register_pytree_node was never intended to be used together with jit? There's no mention of jit in jax/tests/tree_util_test.py.
  • Or maybe they are intended to be compatible but just not implemented yet
  • Or maybe the connection is already implemented somewhere but I haven't been able to find where :)
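For comparison, a small sketch showing that pytree registration does interact with jit: jit flattens its arguments through the pytree registry first and only then checks that each leaf is a valid JAX type, so a registered container of arrays is accepted. Pair and its flatten/unflatten helpers are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Pair:
    # Plain Python container; not a JAX type by itself.
    def __init__(self, a, b):
        self.a = a
        self.b = b


def _flatten(p):
    # Children (traced by jit) and auxiliary data (static, here None).
    return (p.a, p.b), None


def _unflatten(aux_data, children):
    return Pair(*children)


register_pytree_node(Pair, _flatten, _unflatten)


@jax.jit
def total(p):
    # Inside jit, p.a and p.b are tracers rebuilt via _unflatten.
    return p.a + p.b


out = total(Pair(jnp.ones(2), 2 * jnp.ones(2)))
print(out)
```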

Off the top of my head, here are some potential options for letting Module get through the _check_arg() check so that we can return our custom Flax error:

  • Modify JAX's api.py to import tree_util and check the pytree registry in _valid_jaxtype().
  • Modify JAX interpreters/xla.py to take into account the pytree node registry.
  • Attempt something hacky using the if hasattr( '__jax_array__'): part of the _check_arg() check, i.e. try to give nn.Module a __jax_array__ attribute that would pass the check
  • Something else. WDYT? @marcvanzee

melissatan, Apr 14 '22

Normally you would use jax.jit(jitted_step_bad, static_argnums=(0,)) so the model is a static argument. You cannot pass a model as a normal arg because it's not a (pytree of) JAX arrays.

jheek, Apr 19 '22

Hi @jheek, yes, I know that I cannot pass a model as a normal arg. I am able to create a test case using static_argnames or static_argnums that works fine without an error. The concern I'm raising is that the currently proposed solution to this issue, i.e. registering Module as a pytree node, does not appear to work: any custom error that we add inside Module's tree_flatten/tree_unflatten doesn't trigger at the point where the arg type check is performed.

melissatan, Apr 19 '22

@melissatan sorry, I read the issue too quickly. I think the problem is that you need to register every Module subclass as a pytree, rather than the Module class itself. So something like:

class Module:
  @classmethod
  def __init_subclass__(cls, **kwargs: Any) -> None:
    super().__init_subclass__(**kwargs)
    jax.tree_util.register_pytree_node(cls, raise_error_fn, raise_some_error_fn)
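A self-contained version of this suggestion in plain JAX. The error class and the two raising functions are hypothetical placeholders for raise_error_fn and raise_some_error_fn; the key point is that registration happens once per subclass in __init_subclass__, because the pytree registry looks up an object's exact type and does not consult its base classes.

```python
from typing import Any

import jax
import jax.numpy as jnp


class JitPytreeError(TypeError):
    """Hypothetical error type standing in for the one added to errors.py."""


def _raise_on_flatten(module):
    raise JitPytreeError(
        f"{type(module).__name__} is not a pytree; pass it to jit via "
        "static_argnums or static_argnames.")


def _raise_on_unflatten(aux_data, children):
    raise JitPytreeError("Modules cannot be unflattened.")


class Module:
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # The registry matches exact types, so every subclass is
        # registered individually as soon as it is defined.
        jax.tree_util.register_pytree_node(
            cls, _raise_on_flatten, _raise_on_unflatten)


class Foo(Module):
    pass


caught = None
try:
    jax.jit(lambda m, x: x)(Foo(), jnp.ones(2))
except JitPytreeError as err:
    caught = err
print(caught)
```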

jheek, Apr 19 '22

Thanks @jheek, this worked! At first I tried passing lambdas as args to register_pytree_node() and got a bunch of errors. Explicitly defining tree_flatten and tree_unflatten functions got rid of the errors.

melissatan, Apr 29 '22

Spoke too soon: when I implemented the above, I found that the following doctest from module_lifecycle.rst fails:

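from typing import Any, Callable, Iterable

import flax.struct
import flax.linen as nn
import jax
from jax import random

class Partial(flax.struct.PyTreeNode):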
  fn: Callable = flax.struct.field(pytree_node=False)
  args: Iterable[Any]

  def __call__(self, *args, **kwargs):
    return self.fn(*(tuple(self.args) + args), **kwargs)

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda mdl, x: mdl(x) + 1
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, Partial(fn, [dense]))
  
  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

.. testcode::
   :hide:

   x = jax.numpy.ones((3, 2))
   mdl = Foo()
   vars = mdl.init(random.PRNGKey(0), x)
   assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2)

I assume that's because in this particular case, I should not be overwriting tree_util.py's tree_flatten() in the Module __init_subclass__() hook.

Is there some attribute I can use to detect, within __init_subclass__(), whether the Module's __call__() involves a PyTreeNode? I checked dir(mdl) but didn't spot an obvious one.

Here's the error trace:

Document: design_notes/module_lifecycle
---------------------------------------
**********************************************************************
File "design_notes/module_lifecycle.rst", line 691, in default
Failed example:
    x = jax.numpy.ones((3, 2))
    mdl = Foo()
    vars = mdl.init(random.PRNGKey(0), x)
    assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2)
Exception raised:
    Traceback (most recent call last):
      File "<doctest default[0]>", line 3, in <module>
        vars = mdl.init(random.PRNGKey(0), x)
      File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 1236, in init
        method=method, mutable=mutable, **kwargs)
      File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 1202, in init_with_output
        {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
      File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 1169, in apply
        )(variables, *args, **kwargs, rngs=rngs)
      File "/home/runner/work/flax/flax/flax/core/scope.py", line 831, in wrapper
        y = fn(root, *args, **kwargs)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 1450, in scope_fn
        return fn(module.clone(parent=scope), *args, **kwargs)
      File "/home/runner/work/flax/flax/flax/linen/transforms.py", line 1239, in wrapped_fn
        return prewrapped_fn(self, *args, **kwargs)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 350, in wrapped_module_method
        return self._call_wrapped_method(fun, args, kwargs)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 658, in _call_wrapped_method
        y = fun(self, *args, **kwargs)
      File "<doctest default[0]>", line 15, in __call__
      File "/home/runner/work/flax/flax/flax/linen/transforms.py", line 362, in wrapped_fn
        module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)
      File "/home/runner/work/flax/flax/flax/linen/transforms.py", line 138, in get_module_scopes
        new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs))
      File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/tree_util.py", line 182, in tree_map
        leaves, treedef = tree_flatten(tree, is_leaf)
      File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/tree_util.py", line 54, in tree_flatten
        return pytree.flatten(tree, is_leaf)
      File "/home/runner/work/flax/flax/flax/linen/module.py", line 538, in tree_flatten
        raise errors.JitPytreeError()
    jax._src.traceback_util.UnfilteredStackTrace: flax.errors.JitPytreeError: A Flax Linen Module cannot be passed naively through a jitted transformation since it is not a pytree. To pass it as an arg, use static_argnames or static_argnums. See example in error docstring. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JitPytreeError)

melissatan, Apr 29 '22

@jheek it seems like Melissa's change breaks your closure trick. Any ideas on how to resolve this?

marcvanzee, May 02 '22

This error occurs because get_module_scopes should call tree_map with is_leaf=lambda x: isinstance(x, Module), so that Modules are treated as leaves instead of nodes. (And maybe the same should happen in a few other places.)
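A small sketch of the is_leaf idea in plain JAX. Module here is a hypothetical class registered with a flatten function that always raises (mirroring the PR), and to_placeholder is a hypothetical stand-in for get_arg_scope. With is_leaf, tree_map passes each Module to the mapped function whole instead of trying to flatten it.

```python
import jax
import jax.numpy as jnp


class Module:
    """Hypothetical stand-in: flattening it always fails, like in the PR."""


def _raise_not_pytree(*_):
    raise TypeError("Module is not a pytree")


jax.tree_util.register_pytree_node(Module, _raise_not_pytree, _raise_not_pytree)

m = Module()
args = (m, jnp.ones(2))


def to_placeholder(x):
    # Replace Modules with a marker; leave every other leaf unchanged.
    return "module-placeholder" if isinstance(x, Module) else x


# is_leaf stops the traversal at Module instances, so the raising
# flatten function is never called.
new_args = jax.tree_util.tree_map(
    to_placeholder, args, is_leaf=lambda x: isinstance(x, Module))
print(new_args[0])  # module-placeholder

# Without is_leaf, tree_map tries to flatten the Module and hits the error.
try:
    jax.tree_util.tree_map(to_placeholder, args)
except TypeError as err:
    print("TypeError:", err)
```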

jheek, May 03 '22

Thanks for responding @jheek, I took a quick look at the code and unfortunately it's not obvious to me yet.

Should the is_leaf check be added to this block in get_module_scopes (below), or somewhere else?

elif isinstance(x, Module) and isinstance(x.scope, Scope):  # <-- and is_leaf(x)?
  x._try_setup(shallow=True)  # pylint: disable=protected-access
  scopes.append(x.scope)
  attrs = {f.name: getattr(x, f.name) for f in dataclasses.fields(x) if f.name != 'parent' and f.init}
  attrs = jax.tree_map(get_arg_scope, attrs)
  return InstancePlaceholder(x.__class__, attrs, id(x))
return x

You also mentioned that "maybe the same should happen in a few other places". Do we have enough test coverage to catch those other places, given that the only test failure for this PR was the above-mentioned doctest from module_lifecycle.rst? If not, could you recommend a good way to find them?

Thanks!

melissatan, May 24 '22

The is_leaf argument is passed to the tree_map call below, so you get:

new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs), is_leaf=lambda x: isinstance(x, nn.Module))

There are 3 calls to tree_map in get_module_scopes, and I think they all need the is_leaf check.

jheek, May 25 '22

I was caught up with other responsibilities for a while; picking this back up now.

I've made the changes @jheek suggested in the 3 tree_map() calls in get_module_scopes(). But this appears to cause failures in several tests in linen_transforms_test that use nn.jit(), e.g.:

def test_compact_aliasing_collision(self):
    class Foo(nn.Module):
      m1: nn.Module
      m2: nn.Module
      @nn.compact
      def __call__(self, x):
        x = self.m2(self.m1(x))
        return x
    class Bar(nn.Module):
      @nn.compact
      def __call__(self, x):
        dense = nn.Dense(2)
        x = nn.jit(Foo)(dense, dense)(x)  # <-- fails here with the jit pytree error that I'm defining in PR #2270.
        return x
    k = random.PRNGKey(0)
    x = jnp.zeros((2, 2))
    _ = Bar().init(k, x)

I tried specifying static_argnums for the nn.jit() call in the test above, but then got another failure: "Non-hashable static arguments are not supported."
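For context, jit uses static arguments as part of its compilation-cache key, so every value passed through static_argnums must be hashable. A minimal sketch of that requirement with plain jax.jit (apply_layers is a hypothetical function); the exact exception type and message may vary between JAX versions, so the example catches both ValueError and TypeError.

```python
import jax
import jax.numpy as jnp


def apply_layers(layers, x):
    # `layers` is static: a Python-level config baked in at trace time.
    for factor in layers:
        x = x * factor
    return x


jitted = jax.jit(apply_layers, static_argnums=(0,))

# A list is unhashable, so jit cannot use it as part of its cache key.
caught = None
try:
    jitted([2.0, 3.0], jnp.ones(2))
except (ValueError, TypeError) as err:
    caught = err
print(type(caught).__name__, caught)

# A tuple is hashable and works as a static argument.
out = jitted((2.0, 3.0), jnp.ones(2))
print(out)
```

Converting lists to tuples, or holding configuration in frozen dataclasses, is the usual way to make such arguments hashable.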

I'm not sufficiently familiar with the internals of linen_transforms. Is there another workaround that prevents the failure in module_lifecycle.rst (https://github.com/google/flax/issues/853#issuecomment-1113396067)?

melissatan, Jul 06 '22