
Unnecessary check for trace level in nnx.Variable.__setattr__

Open liudangyi opened this issue 6 months ago • 5 comments

In the development of a quantization library, we often need to collect some statistics of activations. Sometimes, the collection happens inside a custom_vjp function, as demonstrated below.

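import jax
import jax.numpy as jnp
from flax import nnx


class QuantStats(nnx.Variable):
  # Uncommenting the next line bypasses nnx.Variable's trace-level check.
  # __setattr__ = object.__setattr__
  pass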


class Model(nnx.Module):
  def __init__(self):
    self.stats = QuantStats({'absmax': jnp.zeros(())})
    self.linear = nnx.Linear(12, 10, rngs=nnx.Rngs(0))

  def __call__(self, x):
    @jax.custom_vjp
    def f(x):
      return fwd(x)[0]

    def fwd(x):
      self.stats.value = {'absmax': jnp.max(jnp.abs(x))}
      return x, ()

    def bwd(_, g):
      # bwd must return a tuple with one cotangent per primal input of f.
      return (g,)

    f.defvjp(fwd, bwd)
    return self.linear(f(x))


def loss_fn(model, x):
  out = model(x)
  return jnp.sum(jnp.abs(out))


model = Model()
loss_fn(model, jnp.full((1, 12), 42.))
print(model.stats['absmax'])
nnx.grad(loss_fn)(model, jnp.full((1, 12), 43.))
print(model.stats['absmax'])

Today, running the code above raises an error like this:

/tmp/ipython-input-90-2380234258.py in fwd(x)
     15 
     16     def fwd(x):
---> 17       self.stats.value = {'absmax': jnp.max(jnp.abs(x))}
     18       return x, ()
     19

.../flax/nnx/variablelib.py in __setattr__(self, name, value)
    274       name != 'value' or not self.mutable
    275     ):
--> 276       raise errors.TraceContextError(
    277         f'Cannot mutate {type(self).__name__} from a different trace level'
    278       )

TraceContextError: Cannot mutate QuantStats from a different trace level (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)

However, JAX doesn't require a custom_vjp function to be pure. The above code will work if we uncomment the __setattr__ = object.__setattr__ line.
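For comparison, here is a minimal pure-JAX sketch of the same pattern with no NNX involved. The module-level `recorded` dict and the names `f_fwd`, `f_bwd`, and `loss` are made up for this illustration and are not part of the quantization library. It only shows that the fwd rule runs and can write to outside state when the function is differentiated eagerly with jax.grad.

```python
import jax
import jax.numpy as jnp

# Hypothetical side-channel for statistics (an assumption for this sketch).
recorded = {}


@jax.custom_vjp
def f(x):
  # Identity in the primal computation.
  return x


def f_fwd(x):
  # Side effect inside the forward rule: stash a statistic of the input.
  recorded['absmax'] = jnp.max(jnp.abs(x))
  return x, ()  # no residuals needed for the backward pass


def f_bwd(_, g):
  # One cotangent per primal input, returned as a tuple.
  return (g,)


f.defvjp(f_fwd, f_bwd)


def loss(x):
  return jnp.sum(jnp.abs(f(x)))


grads = jax.grad(loss)(jnp.full((1, 12), 43.0))
```

Note that if the same function is wrapped in jax.jit, the value stored in `recorded` would be a tracer and must not be used after tracing ends, so this kind of side effect is only safe to rely on outside of jit.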

It seems that NNX is imposing an unnecessary check here.

liudangyi avatar Jul 25 '25 01:07 liudangyi

Pinging @cgarciae because I cannot assign bugs. cc @jshin1394, who found this bug.

liudangyi avatar Jul 25 '25 01:07 liudangyi

Hey @liudangyi! Consider using nnx.custom_vjp and passing the model as an input.

    @nnx.custom_vjp
    def f(m, x):
      return fwd(m, x)[0]

    def fwd(m, x):
      m.stats.value = {'absmax': jnp.max(jnp.abs(x))}
      return x, ()

    def bwd(_, g):
      (m_g, _), y_g = g
      return m_g, y_g

cgarciae avatar Jul 25 '25 01:07 cgarciae

Unfortunately, the actual f is a generic function and cannot easily be changed. We're injecting self.stats.value = ... as a parameter to f.

liudangyi avatar Jul 29 '25 00:07 liudangyi

@cgarciae can we do anything here?

vfdev-5 avatar Oct 24 '25 08:10 vfdev-5

The check in Flax is intentional; see https://github.com/jax-ml/jax/issues/32452.

liudangyi avatar Oct 24 '25 21:10 liudangyi