
Using flax.linen.intercept_methods on an nnx module

Open gardberg opened this issue 10 months ago • 5 comments

Hi! I want to add a hook / interceptor to an nnx module (e.g. one that prints the shapes of the input and output). This is straightforward for a module defined with the older flax.linen API (as shown in the docs):

import flax.linen as nn
import jax.numpy as jnp

class Foo(nn.Module):
  def __call__(self, x):
    return x

def my_interceptor1(next_fun, args, kwargs, context):
  print('calling my_interceptor1')
  return next_fun(*args, **kwargs)

foo = Foo()
with nn.intercept_methods(my_interceptor1):
  _ = foo(jnp.ones([1]))
# >> calling my_interceptor1

However, this does not seem to work for an nnx module:

import flax.linen as nn  # still needed for nn.intercept_methods below
from flax import nnx # Using nnx instead of flax.linen
import jax.numpy as jnp

class Foo(nnx.Module): # changed to nnx.Module
  def __call__(self, x):
    return x

def my_interceptor1(next_fun, args, kwargs, context):
  print('calling my_interceptor1')
  return next_fun(*args, **kwargs)

foo = Foo()
with nn.intercept_methods(my_interceptor1):
  _ = foo(jnp.ones([1]))
# No output

I'm assuming I shouldn't expect the linen interceptor to work for an nnx module.

My questions are:

  1. Is there a way to add an interceptor / hook to an nnx module? And if not:
  2. Are there any plans to add one by either:
     a. Migrating the old implementation?
     b. Implementing a new one?

I would be happy to help out with either one if I could get some pointers :)

gardberg • Mar 29 '25 19:03

Hi @gardberg, NNX avoids the interceptor pattern and instead suggests performing model surgery when new behavior is needed. Take a look at: https://flax.readthedocs.io/en/latest/guides/surgery.html

cgarciae • Apr 03 '25 14:04

That's reasonable, thanks a lot for pointing me to that! If I understood the surgery guide correctly, the way to add a hook would be through monkey patching, e.g.:

from flax import nnx
import jax

class TestModule(nnx.Module):
    def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def __call__(self, x):
        return self.linear(x)

model = TestModule()
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

original_call = type(model).__call__  # the class's function, so self is passed explicitly

hooked_called = False
def hooked_call(self, *args, **kwargs):
    global hooked_called
    hooked_called = True
    print(f"Input to {self.__class__.__name__}:", args, kwargs)
    output = original_call(self, *args, **kwargs)
    print(f"Output from {self.__class__.__name__}:", output)
    return output

model.__call__ = hooked_call

y = model(x)
print(hooked_called)

Unfortunately, hooked_called is False here and the prints in hooked_call never run. My guess is that JAX uses a compiled version of the model's __call__ at runtime, based on the non-overridden __call__ originally defined in the class. Is this correct?

I can get it to work if I subclass the original module and override the method, but I'm looking for a way to avoid subclassing.

I'm sure there's something I'm missing here, am I interpreting the surgery docs correctly, or is there a better way to achieve this kind of hook?

gardberg • Apr 03 '25 18:04

It seems that because __call__ is a special method, Python does not pick up the override set on the instance: as far as I could figure out, an implicit call like model(x) always looks up the __call__ method defined on the class, not on the instance.
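A plain-Python sketch of this lookup rule (no JAX or Flax involved; `Greeter` is just an illustrative name): an implicit invocation `obj()` goes through `type(obj).__call__`, while an explicit attribute access `obj.__call__` finds the instance attribute first.

```python
class Greeter:
    def __call__(self):
        return "class __call__"


g = Greeter()
# Store a function in the instance dict under the same name.
g.__call__ = lambda: "instance __call__"

print(g())           # class __call__     (implicit call looks at type(g))
print(g.__call__())  # instance __call__  (explicit lookup hits the instance first)
```

So the behavior comes from Python's special-method lookup, not from JAX compilation.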

So I would suggest two options. The first is simply to use another name for the method (for instance, call) and then call that method inside __call__:

from flax import nnx
import jax

class TestModule(nnx.Module):
    def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def call(self, x):
        return self.linear(x)
    
    def __call__(self, x):
        return self.call(x)

model = TestModule()
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

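original_call = model.call  # bound method: already carries `model` as self

hooked_called = False
def hooked_call(*args, **kwargs):
    global hooked_called
    hooked_called = True
    print(f"Input to {model.__class__.__name__}:", args, kwargs)
    output = original_call(*args, **kwargs)
    print(f"Output from {model.__class__.__name__}:", output)
    return output

# The instance attribute shadows the class method, so self.call(x) inside
# __call__ now resolves to hooked_call (called without an implicit self).
model.call = hooked_call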

y = model(x)
print(hooked_called)

The other option I have found is to create a subclass of the original model's class that overrides the __call__ method, and assign it to the instance's __class__. This may be more complex, but it's necessary when we don't have control over the model definition.

from flax import nnx
import jax

class TestModule(nnx.Module):
    def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def call(self, x):
        return self.linear(x)
    
    def __call__(self, x):
        return self.call(x)

model = TestModule()
x = jax.random.normal(jax.random.PRNGKey(0), (10,))

original_class = model.__class__

hooked_called = False
class HookedClass(original_class):
    def __call__(self, *args, **kwargs):
        if self is model:  # Only intercept target instance
            global hooked_called
            hooked_called = True
            print(f"Input to {self.__class__.__name__}:", args, kwargs)
            output = super().__call__(*args, **kwargs)
            print(f"Output from {self.__class__.__name__}:", output)
            return output
        return super().__call__(*args, **kwargs)

model.__class__ = HookedClass

y = model(x)
print(hooked_called)

Both of these implementations should set hooked_called to True and show the prints inside the call.

Hope this helps you.

Pere-03 • Apr 14 '25 11:04

@gardberg @Pere-03 I wrote up The Good and The Evil Way to Perform Monkey Patching with some ideas on how to do this. Let me know what you think!

cgarciae • Apr 15 '25 17:04

@Pere-03 Thanks for the feedback and great ideas, appreciate it!

@cgarciae That looks cool, both methods seem to work great! Thanks a lot for writing that up! The second one looks a bit scary, though, hah.

My original goal is to add hooks to / intercept the __call__ method of my model and all of its submodules, as I want to trace the shape of the input through the model and each submodule for debugging purposes. I'm not sure manually overriding the __call__ method for each class is the best way to do that.

I actually ended up adapting the linen intercept approach into a custom base class, which I then subclass for each module I create. The downsides are that it doesn't intercept built-in nnx modules (such as nnx.Linear) that don't inherit from my base class, it's a bit complicated, and it doesn't align with the NNX approach. I'm attaching my adaptation for context.

I think I remember seeing an example of tracing shapes through a Flax model via a function similar to nnx.eval_shape, or maybe that only showed the top-level module's input and output shapes?

import contextlib
import dataclasses
import functools
import threading
import types
from typing import Any, Callable, Iterator

from flax import nnx


class Module(nnx.Module):
    """An extension of nnx.Module that supports method interception.
    
    Every subclass that defines __call__ gets it wrapped at class creation
    time, so hooks can be added via the intercept_methods context manager.
    """
    
    @classmethod
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap the __call__ method if it is defined directly in this subclass
        if '__call__' in cls.__dict__:
            original_call = cls.__call__
            
            @functools.wraps(original_call)
            def wrapped_call(self, *args, **kwargs):
                # Fast path: no interceptors are registered
                if not _global_interceptor_stack:
                    return original_call(self, *args, **kwargs)
                
                return run_interceptors(original_call, self, *args, **kwargs)
                
            cls.__call__ = wrapped_call

# Adapted from flax.linen.module
@dataclasses.dataclass(frozen=True)
class InterceptorContext:
  """Read only state showing the calling context for method interceptors.

  Attributes:
    module: The Module instance whose method is being called.
    method_name: The name of the method being called on the module.
    orig_method: The original method defined on the module. Calling it will
      short circuit all other interceptors.
  """

  module: 'Module'
  method_name: str
  orig_method: Callable[..., Any]


class ThreadLocalStack(threading.local):
  """Thread-local stack."""

  def __init__(self):
    self._storage = []

  def push(self, elem: Any) -> None:
    self._storage.append(elem)

  def pop(self) -> Any:
    return self._storage.pop()

  def __iter__(self) -> Iterator[Any]:
    return iter(reversed(self._storage))

  def __len__(self) -> int:
    return len(self._storage)

  def __repr__(self) -> str:
    return f'{self.__class__.__name__}({self._storage})'


Args = tuple[Any, ...]
Kwargs = dict[str, Any]
NextGetter = Callable[..., Any]
Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any]
_global_interceptor_stack = ThreadLocalStack()

@contextlib.contextmanager
def intercept_methods(interceptor: Interceptor):
  """Context manager that registers an interceptor for module method '__call__'.

  This context manager will run the interceptor for all __call__ methods ran
  inside the context by any subclasses to 'Module'. This includes any submodules
  of the module, i.e. any other Module.__call__ methods that are called by the root
  Module.__call__ method.
  
  The interceptor can for example modify arguments, results, or skip calling the original method.
  
  Args:
    interceptor: A callable that takes (next_method, args, kwargs, context)
                and returns the result of the intercepted method.
  """
  _global_interceptor_stack.push(interceptor)
  try:
    yield
  finally:
    assert _global_interceptor_stack.pop() is interceptor


def run_interceptors(
  orig_method: Callable[..., Any],
  module: 'Module',
  *args,
  **kwargs,
) -> Any:
  """Runs method interceptors."""
  method_name = _get_fn_name(orig_method)
  # Create a bound method that will correctly receive 'self' as first argument
  fun = types.MethodType(orig_method, module)
  context = InterceptorContext(module, method_name, fun)

  def wrap_interceptor(interceptor, fun):
    """Wraps `fun` with `interceptor`."""

    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
      return interceptor(fun, args, kwargs, context)

    return wrapped

  # Wraps interceptors around the original method. The innermost interceptor is
  # the last one added and directly wrapped around the original bound method.
  for interceptor in _global_interceptor_stack:
    fun = wrap_interceptor(interceptor, fun)
  return fun(*args, **kwargs)


def _get_fn_name(fn):
  if isinstance(fn, functools.partial):
    return _get_fn_name(fn.func)
  return getattr(fn, '__name__', 'unnamed_function')

gardberg • Apr 15 '25 19:04