Improve Error Message: Jitting linen.Module
Example code:
model = Model(1)

@jax.jit
def eval_step(model, params, batch):
  logits = model.apply({'params': params}, batch['X'])
  return compute_metrics(logits, batch['y'])

def eval_model(model, params, test_ds):
  metrics = eval_step(model, params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

print(eval_model(model, params, test_ds))
Throws the following general JAX error:
TypeError: Argument 'Model(
    # attributes
    features = 3
)' of type <class '__main__.Model'> is not a valid JAX type
Modules in Linen aren't pytrees, so they can't be flattened/unflattened as needed when entering and exiting JAX transformations. The common pattern is to pass the module to jit as a static argument (via static_argnums), which makes the module instance behave like a constant inside the transformed function.
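The static-argument pattern can be sketched with plain JAX. This is a minimal sketch, assuming a hypothetical `Scale` class as a stand-in for a Linen Module (it is not Flax code). Because jax.jit hashes static arguments, the stand-in defines __hash__ and __eq__; passing the same object as an ordinary traced argument is rejected with a TypeError.

```python
import functools

import jax
import jax.numpy as jnp


class Scale:
  """Hypothetical stand-in for a Linen Module: a plain Python object that
  holds configuration but is not a pytree of arrays."""

  def __init__(self, factor):
    self.factor = factor

  # jax.jit hashes static arguments, so define hashing/equality on the config.
  def __hash__(self):
    return hash(self.factor)

  def __eq__(self, other):
    return isinstance(other, Scale) and self.factor == other.factor

  def apply(self, params, x):
    return self.factor * params["w"] * x


@functools.partial(jax.jit, static_argnums=(0,))
def eval_step(model, params, x):
  # `model` is static: it is not traced, so its Python attributes act as
  # constants baked into the compiled function.
  return model.apply(params, x)


params = {"w": jnp.array(2.0)}
x = jnp.arange(3.0)
y = eval_step(Scale(3.0), params, x)  # [0., 6., 12.]

# Passing the same object as an ordinary (traced) argument fails.
try:
  jax.jit(lambda model, params, x: model.apply(params, x))(Scale(3.0), params, x)
  traced_ok = True
except TypeError:
  traced_ok = False
```

Changing `factor` yields a different hash, so jit retraces and compiles a new version rather than reusing the cached one.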
We should consider registering Modules as pytree nodes solely to throw an error that explains this.
@marcvanzee I could take this. To clarify before I start, is this what you had in mind:
@register_pytree_node_class
class Module:
  ...
and then define tree_flatten and tree_unflatten inside Module such that they raise an informative error? (This could be a TypeError, or I could define a new error type in errors.py.)
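A minimal sketch of this proposal, using jax.tree_util.register_pytree_node_class. The `Module` and `JitPytreeError` names here are hypothetical stand-ins, not the real Flax classes; the point is only that, once a class is registered, any pytree traversal of an instance calls its tree_flatten and surfaces the custom error.

```python
import jax
from jax.tree_util import register_pytree_node_class


class JitPytreeError(TypeError):
  """Hypothetical custom error (the thread suggests it could live in errors.py)."""


@register_pytree_node_class
class Module:
  """Hypothetical minimal stand-in, not the real flax.linen.Module."""

  def tree_flatten(self):
    # Called whenever JAX tries to flatten an instance of this exact class.
    raise JitPytreeError(
        "A Module is not a pytree; pass it to jax.jit via static_argnums "
        "or static_argnames.")

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    raise JitPytreeError("A Module cannot be rebuilt from pytree leaves.")


try:
  jax.tree_util.tree_leaves(Module())
  raised = False
except JitPytreeError:
  raised = True
```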
@melissatan yes, that is the idea 😄.
The naive approach described above doesn't seem to work, at least not out of the box. I investigated using a minimal test case:
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    return x

x = jnp.ones((3, 2))
foo = Foo()
variables = foo.init(random.PRNGKey(0), x)

@jax.jit
def jitted_step_bad(model, params, x):
  unused_logits = model.apply(params, x)
  return params

_ = jitted_step_bad(foo, variables, x)
I observed that the "is not a valid JAX type" error is raised in jax/_src/api.py by the _check_arg() helper used by _cpp_jit():
def _check_arg(arg):
  if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")  # <-- here
Here, _valid_jaxtype() looks for a match in a hashmap called pytype_aval_mappings, which lives in jax/interpreters/xla.py.
However, that hashmap is separate from the _registry mapping in jax/_src/tree_util.py, where register_pytree_node stores its entries.
The pytype_aval_mappings hashmap doesn't appear to be populated with the contents of tree_util.py's _registry; xla.py doesn't even import tree_util.py.
So either:
- register_pytree_node was never intended to be used together with jit (there's no mention of jit in jax/tests/tree_util_test.py),
- or they are intended to be compatible, but it just hasn't been implemented yet,
- or the connection is already implemented somewhere, but I haven't been able to find where :)
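One way to probe the third possibility is to check whether jax.jit consults the pytree registry at all. The sketch below (with hypothetical `Point` and `Opaque` classes) suggests it does: jit flattens its arguments with the registry before checking the leaves, so a registered container passes through, and a registered class whose flatten function raises surfaces that error. If so, the leaf check in _check_arg() only ever sees what flattening produces, and the question becomes why the Module instance was not flattened in the first place.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Point:
  """Hypothetical container registered as a pytree node."""

  def __init__(self, x, y):
    self.x, self.y = x, y


register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),           # flatten: (children, aux_data)
    lambda aux, children: Point(*children),  # unflatten
)


@jax.jit
def norm_sq(p):
  return p.x ** 2 + p.y ** 2


result = norm_sq(Point(jnp.array(3.0), jnp.array(4.0)))  # 25.0


class FlattenError(Exception):
  """Hypothetical error raised by a refusing flatten function."""


class Opaque:
  """Hypothetical class whose registered flatten function always raises."""


def _refuse(*args):
  raise FlattenError("Opaque cannot be flattened")


register_pytree_node(Opaque, _refuse, _refuse)

# jit flattens its arguments first, so the registered function is reached.
try:
  jax.jit(lambda o: 0)(Opaque())
  jit_flattened = False
except FlattenError:
  jit_flattened = True
```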
Off the top of my head, here are some potential options for letting Module get past the _check_arg() check so that we can raise our custom Flax error:
- Modify JAX's api.py to import tree_util and check the pytree registry in _valid_jaxtype().
- Modify JAX's interpreters/xla.py to take the pytree node registry into account.
- Attempt something hacky using the if hasattr( '__jax_array__'): part of the _check_arg() check, i.e. try to give nn.Module a __jax_array__ attribute that would pass the check.
- Something else.
WDYT? @marcvanzee
Normally you would use jax.jit(jitted_step_bad, static_argnums=(0,)) so the model is a static argument. You cannot pass a model as a normal arg because it's not a (pytree of) JAX arrays.
Hi @jheek, yes, I know that I cannot pass a model as a normal arg. I am able to create a test case using static_argnames or static_argnums that works fine without an error. The concern I'm raising is that the currently proposed solution to this issue, i.e. registering Module as a pytree node, does not appear to work, because any custom error that we add inside Module's tree_flatten/tree_unflatten doesn't trigger at the point where the arg type check is performed.
@melissatan sorry, I read the issue too quickly. I think your issue is that you need to register every Module subclass as a pytree rather than the Module class itself. So something like:
class Module():
  @classmethod
  def __init_subclass__(cls, **kwargs: Any) -> None:
    jax.tree_util.register_pytree_node(cls, raise_error_fn, raise_some_error_fn)
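A self-contained sketch of this suggestion, assuming hypothetical `Module`, `Foo`, and `JitPytreeError` stand-ins rather than the real Flax classes. The pytree registry matches exact types, so registering only the base class leaves subclasses such as Foo as opaque leaves; that would explain why the earlier test reached the "not a valid JAX type" check instead of the custom error. Registering each subclass in __init_subclass__ routes it to the raising flatten function, while passing the instance as a static argument still works, because static arguments are hashed rather than flattened.

```python
import jax


class JitPytreeError(TypeError):
  """Hypothetical stand-in for the custom Flax error."""


def _raise_flatten(module):
  raise JitPytreeError(
      f"{type(module).__name__} is not a pytree; pass it to jax.jit via "
      "static_argnums or static_argnames.")


def _raise_unflatten(aux_data, children):
  raise JitPytreeError("A Module cannot be rebuilt from pytree leaves.")


class Module:
  """Hypothetical minimal base class, not the real flax.linen.Module."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # Registry lookups use the exact type, so register each subclass.
    jax.tree_util.register_pytree_node(cls, _raise_flatten, _raise_unflatten)


class Foo(Module):
  pass


# The base class itself is never registered, so it is treated as a leaf.
base_leaves = jax.tree_util.tree_leaves(Module())

# A registered subclass reaches the raising flatten function, also under jit.
try:
  jax.jit(lambda model, x: x * 2)(Foo(), 1.0)
  jit_raised = False
except JitPytreeError:
  jit_raised = True

# As a static argument, Foo() is hashed (by identity here), not flattened.
y_static = jax.jit(lambda model, x: x * 2, static_argnums=0)(Foo(), 1.0)
```

Using named functions rather than lambdas is convenient here because `raise` is a statement and cannot appear inside a lambda body.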
Thanks @jheek, this worked!
At first I tried to pass lambdas as args to register_pytree_node() and got a bunch of errors.
Explicitly defining the tree_flatten and tree_unflatten functions got rid of them.
I spoke too soon: when I implemented the above, I found that the following doctest from module_lifecycle.rst fails:
class Partial(flax.struct.PyTreeNode):
  fn: Callable = flax.struct.field(pytree_node=False)
  args: Iterable[Any]

  def __call__(self, *args, **kwargs):
    return self.fn(*(tuple(self.args) + args), **kwargs)

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda mdl, x: mdl(x) + 1
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, Partial(fn, [dense]))

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

.. testcode::
  :hide:

  x = jax.numpy.ones((3, 2))
  mdl = Foo()
  vars = mdl.init(random.PRNGKey(0), x)
  assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2)
I assume that's because in this particular case, I should not be overwriting tree_util.py's tree_flatten() in the Module __init_subclass__() hook.
Is there some attribute I can use to detect, within __init_subclass__(), whether the Module's __call__() involves a PyTreeNode? I checked dir(mdl) but didn't spot an obvious one.
Here's the error trace:
Document: design_notes/module_lifecycle
---------------------------------------
**********************************************************************
File "design_notes/module_lifecycle.rst", line 691, in default
Failed example:
x = jax.numpy.ones((3, 2))
mdl = Foo()
vars = mdl.init(random.PRNGKey(0), x)
assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2)
Exception raised:
Traceback (most recent call last):
File "<doctest default[0]>", line 3, in <module>
vars = mdl.init(random.PRNGKey(0), x)
File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 1236, in init
method=method, mutable=mutable, **kwargs)
File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 1202, in init_with_output
{}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 1169, in apply
)(variables, *args, **kwargs, rngs=rngs)
File "/home/runner/work/flax/flax/flax/core/scope.py", line 831, in wrapper
y = fn(root, *args, **kwargs)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 1450, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/runner/work/flax/flax/flax/linen/transforms.py", line 1239, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 350, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 658, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "<doctest default[0]>", line 15, in __call__
File "/home/runner/work/flax/flax/flax/linen/transforms.py", line 362, in wrapped_fn
module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)
File "/home/runner/work/flax/flax/flax/linen/transforms.py", line 138, in get_module_scopes
new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs))
File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/tree_util.py", line 182, in tree_map
leaves, treedef = tree_flatten(tree, is_leaf)
File "/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/jax/_src/tree_util.py", line 54, in tree_flatten
return pytree.flatten(tree, is_leaf)
File "/home/runner/work/flax/flax/flax/linen/module.py", line 538, in tree_flatten
raise errors.JitPytreeError()
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.JitPytreeError: A Flax Linen Module cannot be passed naively through a jitted transformation since it is not a pytree. To pass it as an arg, use static_argnames or static_argnums. See example in error docstring. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.JitPytreeError)
@jheek it seems like Melissa's change breaks your closure trick. Any ideas on how to resolve this?
This error comes from the fact that get_module_scopes should call tree_map with is_leaf=lambda x: isinstance(x, Module), so that Modules are treated as leaves instead of nodes. (And maybe the same should happen in a few other places.)
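The effect of is_leaf can be sketched with jax.tree_util.tree_map, again with hypothetical stand-ins (`Module`, `Dense`, `JitPytreeError`) whose registered flatten function raises. Without is_leaf, tree_map tries to flatten the module and hits the error; with is_leaf=lambda x: isinstance(x, Module), the predicate is checked before flattening, so the module is passed to the mapped function whole.

```python
import jax


class JitPytreeError(TypeError):
  """Hypothetical error raised when a Module would be flattened."""


def _raise(*args):
  raise JitPytreeError("Modules are not pytrees.")


class Module:
  """Hypothetical stand-in whose subclasses refuse to be flattened."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    jax.tree_util.register_pytree_node(cls, _raise, _raise)


class Dense(Module):
  pass


def describe(leaf):
  return type(leaf).__name__


tree = {"layer": Dense(), "scale": 2.0}

# Without is_leaf, tree_map flattens Dense and hits the raising function.
try:
  jax.tree_util.tree_map(describe, tree)
  flatten_raised = False
except JitPytreeError:
  flatten_raised = True

# With is_leaf, the predicate runs before flattening, so Dense stays whole.
out = jax.tree_util.tree_map(
    describe, tree, is_leaf=lambda x: isinstance(x, Module))
# out == {"layer": "Dense", "scale": "float"}
```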
Thanks for responding, @jheek. I took a quick look at the code, but unfortunately the fix isn't obvious to me yet.
Should the is_leaf tree_map be added to this block in get_module_scopes (below), or somewhere else?
elif isinstance(x, Module) and isinstance(x.scope, Scope):  # <-- and is_leaf(x) ?
  x._try_setup(shallow=True)  # pylint: disable=protected-access
  scopes.append(x.scope)
  attrs = {f.name: getattr(x, f.name) for f in dataclasses.fields(x) if f.name != 'parent' and f.init}
  attrs = jax.tree_map(get_arg_scope, attrs)
  return InstancePlaceholder(x.__class__, attrs, id(x))
return x
Regarding "(And maybe the same should happen in a few other places)":
Do we have enough test coverage to catch those other places, given that the only test failure for this PR was the aforementioned doctest from module_lifecycle.rst? If not, could you recommend a good way to find those other places?
Thanks!
The is_leaf argument should be passed to the tree_map call below, so you get:
new_args, new_kwargs = jax.tree_map(get_arg_scope, (args, kwargs), is_leaf=lambda x: isinstance(x, nn.Module))
There are 3 calls to tree_map in get_module_scopes, and I think they all need the is_leaf check.
I was caught up with other responsibilities for a while; picking this back up now.
I've made the changes @jheek suggested to the 3 tree_map() calls in get_module_scopes(), but this appears to cause failures in several tests that use nn.jit() in linen_transforms_test, e.g.:
def test_compact_aliasing_collision(self):
  class Foo(nn.Module):
    m1: nn.Module
    m2: nn.Module

    @nn.compact
    def __call__(self, x):
      x = self.m2(self.m1(x))
      return x

  class Bar(nn.Module):
    @nn.compact
    def __call__(self, x):
      dense = nn.Dense(2)
      x = nn.jit(Foo)(dense, dense)(x)  # <-- fails here, and outputs the jit pytree error that I'm defining in PR #2270.
      return x

  k = random.PRNGKey(0)
  x = jnp.zeros((2, 2))
  _ = Bar().init(k, x)
I tried specifying static_argnums for the nn.jit() call in the test above, but then got another failure: "Non-hashable static arguments are not supported."
I'm not sufficiently familiar with the internals of linen_transforms. Is there another workaround that prevents the failure in module_lifecycle.rst (https://github.com/google/flax/issues/853#issuecomment-1113396067)?
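The second failure reflects a general jax.jit rule: static arguments are used as part of the compilation cache key, so they must be hashable. The sketch below shows that rule with plain dataclasses (hypothetical `MutableConfig` and `FrozenConfig`), not with Flax Modules: a regular @dataclasses.dataclass sets __hash__ to None, while a frozen one is hashable. Whether a particular Module instance, such as the shared Dense passed to nn.jit(Foo) above, is hashable depends on Flax internals and is not shown here. The exact exception type JAX raises for an unhashable static argument is also treated as an assumption, so both TypeError and ValueError are caught.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass  # eq=True and not frozen, so __hash__ is set to None
class MutableConfig:
  scale: float


@dataclasses.dataclass(frozen=True)  # frozen + eq: hash derived from fields
class FrozenConfig:
  scale: float


def apply(cfg, x):
  return cfg.scale * x


jitted = jax.jit(apply, static_argnums=0)
x = jnp.ones(2)

# An unhashable static argument cannot be used as a cache key.
try:
  jitted(MutableConfig(2.0), x)
  mutable_failed = False
except (TypeError, ValueError):
  mutable_failed = True

# A hashable (frozen) config works as a static argument.
y = jitted(FrozenConfig(2.0), x)  # [2., 2.]
```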