Melissa Tan
@marcvanzee I could take this. To clarify before I start, is this what you had in mind:

```py
@register_pytree_node_class
class Module:
    ...
```

and then define `tree_flatten` and `tree_unflatten` inside `Module`...
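For reference, here is roughly how the `register_pytree_node_class` decorator from `jax.tree_util` is used on an ordinary class. This is a minimal sketch: `ToyModule` is a hypothetical stand-in, not Flax's actual `nn.Module`, and the split of fields into array children versus static `aux_data` is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyModule:
    """Hypothetical module-like class (NOT flax.linen.Module)."""

    def __init__(self, weight, name):
        self.weight = weight  # array-valued field: becomes a pytree leaf
        self.name = name      # static metadata: carried in aux_data

    def tree_flatten(self):
        # Return (children, aux_data); aux_data should be hashable.
        children = (self.weight,)
        aux_data = (self.name,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild an instance from the static data and the (possibly traced) leaves.
        (name,) = aux_data
        (weight,) = children
        return cls(weight, name)


m = ToyModule(jnp.ones(3), "dense")
leaves, treedef = jax.tree_util.tree_flatten(m)
m2 = jax.tree_util.tree_unflatten(treedef, leaves)


@jax.jit
def scale(module):
    # Because ToyModule is a pytree, it can be passed to jit as a regular argument.
    return module.weight * 2
```

Once the class is registered, `scale(m)` traces through `weight` while `name` stays a Python string.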
The naive approach described above doesn't seem to work, at least not out of the box. Upon investigation using a minimal test case ```py class Foo(nn.Module): @compact def __call__(self, x):...
Hi @jheek, yes, I know that I cannot pass a model as a normal argument. I am able to create a test case using `static_argnames` or `static_argnums` that works fine...
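The `static_argnames` workaround looks roughly like this in plain JAX. A minimal sketch, assuming a hypothetical `Scaler` object in place of a real model; static arguments must be hashable, and here `Scaler` relies on Python's default identity-based hash.

```python
import functools

import jax
import jax.numpy as jnp


class Scaler:
    """Hypothetical non-array 'model'; hashable via the default object hash."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


# Mark `model` as static so jit treats it as a compile-time constant
# instead of trying to convert it into an array.
@functools.partial(jax.jit, static_argnames=("model",))
def apply(model, x):
    return model(x)


s = Scaler(3.0)
y = apply(s, jnp.arange(3.0))
```

One trade-off: with an identity-based hash, every new `Scaler` instance counts as a different static value and triggers a retrace, which is part of why registering the model as a pytree is attractive.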
Thanks @jheek, this worked! At first I tried to pass lambdas as arguments to `register_pytree_node()` and was getting a bunch of errors. Explicitly defining the `tree_flatten` and `tree_unflatten` functions got rid...
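For comparison, registering a class with `jax.tree_util.register_pytree_node` and explicitly named flatten/unflatten functions might look like the sketch below. `Pair`, `pair_flatten`, and `pair_unflatten` are hypothetical names; the point is only the contract: flatten returns `(children, aux_data)`, and unflatten takes `(aux_data, children)`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Pair:
    """Hypothetical two-field container."""

    def __init__(self, a, b):
        self.a = a
        self.b = b


def pair_flatten(p):
    # children first, then aux_data (no static data here, so None)
    return (p.a, p.b), None


def pair_unflatten(aux_data, children):
    # Note the argument order: aux_data comes before children.
    a, b = children
    return Pair(a, b)


register_pytree_node(Pair, pair_flatten, pair_unflatten)

doubled = jax.tree_util.tree_map(lambda x: x * 2, Pair(jnp.array(1.0), jnp.array(2.0)))
```

Named functions are not required (lambdas with the same signatures also work), but they make the argument order easier to check.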
Spoke too soon: when I implemented the above, I found that the following doctest from `module_lifecycle.rst` fails: ```py class Partial(flax.struct.PyTreeNode): fn: Callable = flax.struct.field(pytree_node=False) args: Iterable[Any] def __call__(self, *args, **kwargs):...
Thanks for responding, @jheek. I took a quick look at the code, and unfortunately it's not obvious to me yet. Should the `tree_map` call with `is_leaf` be added to this block in...
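As a refresher on what `is_leaf` does in `jax.tree_util.tree_map`: it lets the caller stop recursion at a node that would otherwise be flattened. A minimal sketch, assuming a hypothetical registered class `Opaque` standing in for a Module-like pytree node; this does not reproduce the actual change in Flax.

```python
import jax


class Opaque:
    """Hypothetical pytree node we sometimes want to treat as a single leaf."""

    def __init__(self, payload):
        self.payload = payload


# Register Opaque so that, by default, tree_map recurses into its payload.
jax.tree_util.register_pytree_node(
    Opaque,
    lambda o: ((o.payload,), None),
    lambda _, children: Opaque(children[0]),
)

tree = {"a": 1, "b": Opaque(2)}

# Default behavior: the function is applied to the payload inside Opaque.
recursed = jax.tree_util.tree_map(lambda x: x * 10, tree)

# With is_leaf: Opaque instances are passed to the function whole.
labels = jax.tree_util.tree_map(
    lambda x: "opaque" if isinstance(x, Opaque) else x,
    tree,
    is_leaf=lambda x: isinstance(x, Opaque),
)
```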
I was caught up with other responsibilities for a while and am picking this back up now. I've made the changes @jheek suggested to the 3 `tree_map()` calls in `get_module_scopes()`. But this appears...