Using flax.linen.intercept_methods on an nnx module
Hi! I want to add a hook / interceptor to an nnx module (e.g. that prints the shapes of the input and output). This is straightforward to do for a module defined via the old flax.linen api (as shown here in the docs):
import flax.linen as nn
import jax.numpy as jnp

class Foo(nn.Module):
    def __call__(self, x):
        return x

def my_interceptor1(next_fun, args, kwargs, context):
    print('calling my_interceptor1')
    return next_fun(*args, **kwargs)

foo = Foo()
with nn.intercept_methods(my_interceptor1):
    _ = foo(jnp.ones([1]))
# >> calling my_interceptor1
However, this does not seem to work for an nnx module:
import flax.linen as nn  # still needed for nn.intercept_methods below
from flax import nnx  # Using nnx instead of flax.linen
import jax.numpy as jnp

class Foo(nnx.Module):  # changed to nnx.Module
    def __call__(self, x):
        return x

def my_interceptor1(next_fun, args, kwargs, context):
    print('calling my_interceptor1')
    return next_fun(*args, **kwargs)

foo = Foo()
with nn.intercept_methods(my_interceptor1):
    _ = foo(jnp.ones([1]))
# No output
I assume the linen interceptor is simply not expected to work for an nnx module.
My questions are:
- Is there a way to add an interceptor / hook to an nnx module? And if not:
- Are there any plans to add one by either:
  a. Migrating the old implementation?
  b. Implementing a new one?
I would be happy to help out with either one if I could get some pointers :)
Hi @gardberg, NNX avoids the interceptor pattern and instead suggests performing model surgery when new behavior is needed. Take a look at: https://flax.readthedocs.io/en/latest/guides/surgery.html
That's reasonable, thanks a lot for pointing me to that! If I understood the surgery guide correctly the way to add a hook would be through monkey patching, e.g.
from flax import nnx
import jax

class TestModule(nnx.Module):
    def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def __call__(self, x):
        return self.linear(x)

model = TestModule()
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

original_call = model.__call__
hooked_called = False

def hooked_call(self, *args, **kwargs):
    global hooked_called
    hooked_called = True
    print(f"Input to {self.__class__.__name__}:", args, kwargs)
    output = original_call(self, *args, **kwargs)
    print(f"Output from {self.__class__.__name__}:", output)
    return output

model.__call__ = hooked_call
y = model(x)
print(hooked_called)
Unfortunately, hooked_called is still False here and the prints in hooked_call are never executed. My guess is that JAX uses a compiled version of the model's __call__ at runtime, which is based on the non-overridden __call__ method originally defined in the class. Is this correct?
I can get it to work if I subclass the original module and override the method, but I'm looking for a way to avoid subclassing.
I'm sure there's something I'm missing here. Am I interpreting the surgery docs correctly, or is there a better way to achieve this kind of hook?
It seems that, because __call__ is a special method, Python ignores the assignment on the instance: as far as I could figure out, model(x) always looks up __call__ on the class, never on the instance. So this looks like plain Python behavior rather than something JAX compiles.
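To make that lookup rule concrete, here is a small pure-Python sketch with no Flax involved. The Greeter class is a made-up stand-in; the only assumption is that nnx modules follow the same rule as any other Python object.

```python
class Greeter:
    def __call__(self):
        return "original"

g = Greeter()

# Instance-level assignment: stored in g.__dict__, but ignored by g().
g.__call__ = lambda: "instance patch"
instance_result = g()           # implicit call goes through type(g).__call__
explicit_result = g.__call__()  # explicit attribute access does see the instance attribute

# Class-level assignment: this is what the implicit call actually uses.
Greeter.__call__ = lambda self: "class patch"
class_result = g()

print(instance_result, explicit_result, class_result)
```

Only the class-level assignment changes what g() does, which is why both options in this reply avoid assigning __call__ on the instance.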
So, I would suggest two options. The first is simply to use another name for the method (for instance, call), and then refer to this method inside __call__:
from flax import nnx
import jax

class TestModule(nnx.Module):
    def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def call(self, x):
        return self.linear(x)

    def __call__(self, x):
        return self.call(x)

model = TestModule()
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

original_call = model.call  # bound method, so it already carries `self`
hooked_called = False

# A function stored on the instance is not bound, so it receives no `self`.
def hooked_call(*args, **kwargs):
    global hooked_called
    hooked_called = True
    print(f"Input to {type(model).__name__}:", args, kwargs)
    output = original_call(*args, **kwargs)
    print(f"Output from {type(model).__name__}:", output)
    return output

model.call = hooked_call  # ordinary attribute, so the instance-level patch is found
y = model(x)
print(hooked_called)
The other option I have found is to create a subclass of the original model's class that overrides the __call__ method, and then reassign the instance's __class__ to it. This is more complex, but it's a must when we don't have control over the model definition.
from flax import nnx
import jax

class TestModule(nnx.Module):
    def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def call(self, x):
        return self.linear(x)

    def __call__(self, x):
        return self.call(x)

model = TestModule()
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

original_class = model.__class__
hooked_called = False

class HookedClass(original_class):
    def __call__(self, *args, **kwargs):
        global hooked_called
        if self is model:  # Only intercept target instance
            hooked_called = True
            print(f"Input to {self.__class__.__name__}:", args, kwargs)
            output = super().__call__(*args, **kwargs)
            print(f"Output from {self.__class__.__name__}:", output)
            return output
        return super().__call__(*args, **kwargs)

model.__class__ = HookedClass
y = model(x)
print(hooked_called)
Both of these implementations should set hooked_called to True and show the prints from inside the hook.
Hope this helps you.
@gardberg @Pere-03 I created The Good and The Evil Way to Perform Monkey Patching with some ideas how to do this. Let me know what you think!
@Pere-03 Thanks for the feedback and great ideas, appreciate it!
@cgarciae That looks cool, both methods seem to work great! Thanks a lot for writing that up! Second one looks a bit scary though hah
My original goal is to add hooks to / intercept the __call__ method of my model and all of its submodules, as I want to trace the shape of the input through the model and each submodule for debugging purposes. I'm not sure manually overriding the __call__ method for each class is the best way to do that.
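One way this could be done without touching each class by hand is to patch __call__ at the class level for every class found in the module tree, and restore it afterwards. The sketch below is only an illustration: trace_calls and its arguments are hypothetical names, and plain Python stand-ins (Base, Dense, MLP, FakeArray) replace nnx.Module and arrays so the snippet runs without Flax. It assumes submodules are stored as ordinary instance attributes visible through vars(); modules kept in lists or dicts would need extra handling.

```python
import contextlib
import functools

def _wrap(original, log):
    """Wrap a __call__ function so it records argument and result shapes."""
    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        out = original(self, *args, **kwargs)
        in_shapes = [getattr(a, "shape", None) for a in args]
        log.append((type(self).__name__, in_shapes, getattr(out, "shape", None)))
        return out
    return wrapper

@contextlib.contextmanager
def trace_calls(root, base, log):
    """Temporarily patch __call__ on the classes of `root` and its submodules.

    `base` is the module base class to search for (nnx.Module in this thread).
    """
    classes, stack, seen = set(), [root], set()
    while stack:  # walk instance attributes that are `base` instances
        obj = stack.pop()
        if id(obj) in seen:
            continue
        seen.add(id(obj))
        classes.add(type(obj))
        stack.extend(v for v in vars(obj).values() if isinstance(v, base))

    # Only classes that define __call__ themselves are patched in this sketch.
    saved = {c: vars(c)["__call__"] for c in classes if "__call__" in vars(c)}
    for cls, original in saved.items():
        cls.__call__ = _wrap(original, log)
    try:
        yield log
    finally:
        for cls, original in saved.items():
            cls.__call__ = original  # restore the unpatched methods

# --- stand-ins so the example runs without Flax (all hypothetical) ---
class FakeArray:
    def __init__(self, shape):
        self.shape = shape

class Base:  # plays the role of nnx.Module
    pass

class Dense(Base):
    def __init__(self, out_features):
        self.out_features = out_features

    def __call__(self, x):
        return FakeArray(x.shape[:-1] + (self.out_features,))

class MLP(Base):
    def __init__(self):
        self.hidden = Dense(32)
        self.head = Dense(4)

    def __call__(self, x):
        return self.head(self.hidden(x))

model = MLP()
log = []
with trace_calls(model, Base, log):
    model(FakeArray((8, 16)))
for entry in log:
    print(entry)
```

Entries are appended when each call returns, so inner modules appear before their parent. Because the patch is on the class, it affects every instance of that class while the context is active.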
I actually ended up adapting the linen intercept approach in a custom class, which I then use as the base class for each module I create. The downsides are that it doesn't intercept native nnx.Module objects, it's a bit complicated, and it doesn't align with the approach of nnx. Attaching my adaptation code for context.
I think I remember seeing an example of tracing the shapes through a flax model via some function similar to nnx.eval_shape, or maybe that only covered the top-level module's input and output shapes?
import contextlib
import dataclasses
import functools
import threading
import types
from typing import Any, Callable, Iterator

from flax import nnx


class Module(nnx.Module):
    """An extension of nnx.Module that supports method interception.

    This subclass intercepts the __call__ method at class initialization
    to allow for adding hooks via the intercept_methods context manager.
    """

    def __init_subclass__(cls, **kwargs):  # implicitly a classmethod
        super().__init_subclass__(**kwargs)
        # Wrap the __call__ method if it exists in this subclass
        if '__call__' in cls.__dict__:
            original_call = cls.__call__

            @functools.wraps(original_call)
            def wrapped_call(self, *args, **kwargs):
                # Fast path: no interceptors registered in this thread
                if not _global_interceptor_stack:
                    return original_call(self, *args, **kwargs)
                return run_interceptors(original_call, self, *args, **kwargs)

            cls.__call__ = wrapped_call

    def __call__(self, *args, **kwargs):
        # This is the base implementation that will be inherited if not overridden
        return super().__call__(*args, **kwargs)

# Adapted from flax.linen.module
@dataclasses.dataclass(frozen=True)
class InterceptorContext:
    """Read only state showing the calling context for method interceptors.

    Attributes:
      module: The Module instance whose method is being called.
      method_name: The name of the method being called on the module.
      orig_method: The original method defined on the module. Calling it will
        short circuit all other interceptors.
    """

    module: 'Module'
    method_name: str
    orig_method: Callable[..., Any]

class ThreadLocalStack(threading.local):
    """Thread-local stack."""

    def __init__(self):
        self._storage = []

    def push(self, elem: Any) -> None:
        self._storage.append(elem)

    def pop(self) -> Any:
        return self._storage.pop()

    def __iter__(self) -> Iterator[Any]:
        # Iterate from the most recently pushed element to the oldest
        return iter(reversed(self._storage))

    def __len__(self) -> int:
        return len(self._storage)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self._storage})'


Args = tuple[Any, ...]  # variable-length tuple of positional arguments
Kwargs = dict[str, Any]
NextGetter = Callable[..., Any]
Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any]
_global_interceptor_stack = ThreadLocalStack()

@contextlib.contextmanager
def intercept_methods(interceptor: Interceptor):
    """Context manager that registers an interceptor for module method '__call__'.

    This context manager will run the interceptor for all __call__ methods run
    inside the context by any subclass of 'Module'. This includes any submodules
    of the module, i.e. any other Module.__call__ methods that are called by the
    root Module.__call__ method.

    The interceptor can for example modify arguments, results, or skip calling
    the original method.

    Args:
      interceptor: A callable that takes (next_method, args, kwargs, context)
        and returns the result of the intercepted method.
    """
    _global_interceptor_stack.push(interceptor)
    try:
        yield
    finally:
        # Pop outside the assert so the stack is still unwound under `python -O`
        popped = _global_interceptor_stack.pop()
        assert popped is interceptor

def run_interceptors(
    orig_method: Callable[..., Any],
    module: 'Module',
    *args,
    **kwargs,
) -> Any:
    """Runs method interceptors."""
    method_name = _get_fn_name(orig_method)
    # Create a bound method that will correctly receive 'self' as first argument
    fun = types.MethodType(orig_method, module)
    context = InterceptorContext(module, method_name, fun)

    def wrap_interceptor(interceptor, fun):
        """Wraps `fun` with `interceptor`."""

        @functools.wraps(fun)
        def wrapped(*args, **kwargs):
            return interceptor(fun, args, kwargs, context)

        return wrapped

    # Wraps interceptors around the original method. The innermost interceptor is
    # the last one added and directly wrapped around the original bound method.
    for interceptor in _global_interceptor_stack:
        fun = wrap_interceptor(interceptor, fun)
    return fun(*args, **kwargs)


def _get_fn_name(fn):
    if isinstance(fn, functools.partial):
        return _get_fn_name(fn.func)
    return getattr(fn, '__name__', 'unnamed_function')