flax

Problem training model with jax.jvp in forward pass

Status: Open · Thomas-Christie opened this issue 10 months ago · 3 comments

Hi, thanks for the library! I'm currently running into issues when trying to train a "linearised" NN model. In my setup I take a standard NN with weights $\theta^*$, denoted $f(\cdot;\theta^*)$, and form a "linearised" version of it by taking a first-order Taylor expansion about $\theta^*$. Mathematically:

$$f^{\text{lin}}_{\theta^*}(x;\theta) = f(x;\theta^*) + J_{\theta} f(x;\theta^*)\,(\theta - \theta^*)$$
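Once the weights are a plain pytree of arrays, the expansion can be evaluated with a single `jax.jvp` call. The sketch below is illustrative only: it replaces the Flax model with a hypothetical two-layer `mlp` over a dict of arrays and checks two properties of the expansion, namely that it reproduces $f(x;\theta^*)$ at $\tau = 0$ and that it approximates $f(x;\theta^* + \tau)$ for small $\tau$.

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    # Hypothetical stand-in for f(x; theta): a two-layer tanh network over a dict of arrays.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def f_lin(params_star, tau, x):
    # f(x; theta*) + J_theta f(x; theta*) tau, with tau = theta - theta*.
    # jax.jvp returns f(params_star) and the Jacobian-vector product J tau together.
    out, jvp_out = jax.jvp(lambda p: mlp(p, x), (params_star,), (tau,))
    return out + jvp_out


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params_star = {
    "w1": jax.random.normal(k1, (1, 20)),
    "b1": jnp.zeros((20,)),
    "w2": jax.random.normal(k2, (20, 1)),
    "b2": jnp.zeros((1,)),
}
x = jnp.linspace(-1.0, 1.0, 8).reshape(-1, 1)

# At tau = 0 the expansion equals the original network.
tau_zero = jax.tree.map(jnp.zeros_like, params_star)
lin_at_zero = f_lin(params_star, tau_zero, x)

# For small tau it matches the network evaluated at theta* + tau up to O(|tau|^2).
tau_small = jax.tree.map(lambda p: 1e-3 * jnp.ones_like(p), params_star)
theta = jax.tree.map(jnp.add, params_star, tau_small)
taylor_error = jnp.max(jnp.abs(f_lin(params_star, tau_small, x) - mlp(theta, x)))
```

Because `f_lin` is linear in `tau`, training only `tau` turns the regression loss into a convex quadratic in `tau`, which is the usual motivation for this kind of linearisation.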

I then want to train this "linearised" model, keeping the original weights $\theta^*$ fixed and only updating the weights $\theta$ (or, equivalently, the weight update $\tau = \theta - \theta^*$). To freeze $\theta^*$, I tried to adopt the approach described in https://github.com/google/flax/issues/4167. A forward pass with my model works fine, but as soon as I try to train it I run into issues. An example is as follows:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Original model
class RegressionModel(nnx.Module):

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.linear1 = nnx.Linear(1, 20, rngs=rngs)
        self.linear2 = nnx.Linear(20, 20, rngs=rngs)
        self.linear3 = nnx.Linear(20, 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.tanh(self.linear1(x))
        x = nnx.tanh(self.linear2(x))
        x = self.linear3(x)
        return x


class LinearisedModel(nnx.Module):

    def __init__(self, original_model: nnx.Module) -> None:
        self.graph_def, self.original_weights = nnx.split(original_model)
        self.weight_update = jax.tree.map(
            nnx.Param, jax.tree.map(jnp.zeros_like, self.original_weights)
        )  # Want to be trainable, corresponds to "tau" above
        self.original_weights = jax.tree.map(
            nnx.Param, self.original_weights
        )  # Want to be not trainable

    def __call__(self, x):
        def _model_fn(weights):
            return nnx.call((self.graph_def, weights))(x)[0]

        original_pred, upd = jax.jvp(
            _model_fn,
            (self.original_weights,),
            (self.weight_update,),
        )
        return original_pred + upd


original_model = RegressionModel(rngs=nnx.Rngs(42))
linearised_model = LinearisedModel(original_model)

trainable_params = nnx.All(nnx.Param, nnx.PathContains("weight_update"))
optimizer = nnx.Optimizer(
    linearised_model,
    tx=optax.adamw(3e-4),
    wrt=trainable_params,
)


def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return optax.squared_error(pred, y).mean()

    diff_state = nnx.DiffState(0, trainable_params)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(grads)


x = jnp.ones((1, 1))
y = jnp.ones((1, 1))
train_step(linearised_model, optimizer, x, y)

The error I'm getting is:

TypeError: Argument 'Param(
  value=Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
    primal = Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0.]], dtype=float32)
    tangent = Traced<ShapedArray(float32[1,20])>with<JaxprTrace(level=1/0)> with
      pval = (ShapedArray(float32[1,20]), None)
      recipe = LambdaBinding()
)' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.

From the stack trace, it seems to stem from the jax.jvp() call.

Can you see what I'm doing wrong? I'm happy to provide more details if necessary. A PyTorch version of what I'm trying to achieve can be found here. Thanks in advance!

Posted by Thomas-Christie on Mar 18 '25

Hey! The issue is that jax.jvp doesn't handle nnx.Variable instances (like Param) correctly. To fix it, call nnx.state on them before passing them in; I tested that this runs with your code:

original_pred, upd = jax.jvp(
  _model_fn,
  (nnx.state(self.original_weights),),
  (nnx.state(self.weight_update),),
)

I think we just need to add nnx.jvp to make this cleaner; I've already seen a couple of users who need it.

Posted by cgarciae on Mar 19 '25
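The fix above works because `jax.jvp` only accepts primals and tangents whose pytree leaves are arrays (or Python scalars). Per the reply, the `nnx.Param` objects in the original snippet reached `jax.jvp` as leaves, which produced the "is not a valid JAX type" error. A small helper like the hypothetical `non_array_leaves` below can flag such leaves before the call; `Opaque` is an illustrative stand-in for any wrapper that is not registered as a pytree, not a Flax type.

```python
import numpy as np
import jax
import jax.numpy as jnp


class Opaque:
    # Illustrative wrapper that is NOT registered as a pytree node,
    # so jax.tree_util treats the whole object as a single leaf.
    def __init__(self, value):
        self.value = value


def non_array_leaves(tree):
    """Hypothetical helper: key paths of leaves that jax.jvp cannot trace."""
    bad = []
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
        if not isinstance(leaf, (jax.Array, np.ndarray, bool, int, float, complex)):
            bad.append(jax.tree_util.keystr(path))
    return bad


ok_tree = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
bad_tree = {"w": Opaque(jnp.ones((2, 3))), "b": jnp.zeros((3,))}

print(non_array_leaves(ok_tree))   # []
print(non_array_leaves(bad_tree))  # ["['w']"]
```

Running such a check on the primal and tangent arguments points at the offending attribute by path, which is usually quicker to act on than the traced value printed in the `TypeError`.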

Great, thanks, that worked! And yeah, I think having nnx.jvp would probably be a bit more intuitive - I actually searched to see if this existed whilst debugging.

Posted by Thomas-Christie on Mar 19 '25
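Once the weights are plain pytrees (as after `nnx.state`), `jax.linearize` is another way to express the same expansion: it evaluates $f(x;\theta^*)$ once and returns a function that applies $J_\theta f(x;\theta^*)$ to any $\tau$. The sketch below uses a hypothetical one-layer `model_fn` in place of `nnx.call((graph_def, weights))(x)[0]`. This is a swapped-in alternative, not what the thread uses, and it mainly pays off when the same batch is reused for several $\tau$, since `jax.linearize` keeps the intermediates needed by the linear map in memory.

```python
import jax
import jax.numpy as jnp


def model_fn(weights, x):
    # Hypothetical stand-in for nnx.call((graph_def, weights))(x)[0].
    return jnp.tanh(x @ weights["w"] + weights["b"])


weights_star = {"w": jnp.array([[0.5, -1.0, 2.0]]), "b": jnp.array([0.1, 0.0, -0.2])}
x = jnp.array([[1.0], [2.0]])

# Primal output at theta* and a linear function tau -> J tau.
original_pred, jvp_fn = jax.linearize(lambda w: model_fn(w, x), weights_star)

tau = {"w": jnp.full((1, 3), 0.01), "b": jnp.full((3,), -0.02)}
lin_pred = original_pred + jvp_fn(tau)

# Reference: the explicit jax.jvp form used in the thread.
ref_pred, ref_upd = jax.jvp(lambda w: model_fn(w, x), (weights_star,), (tau,))
```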

@cgarciae I assume the same holds for vjp? Is there a reason these functional transformations don't exist yet? Happy to contribute them if it's as straightforward as it seems.

Posted by marcelroed on Jun 17 '25
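On the vjp question: `jax.vjp` takes the same kind of plain pytrees as `jax.jvp`, so the `nnx.state` workaround should plausibly carry over (an assumption, not tested against Flax here). The pure-JAX sketch below uses a hypothetical `model_fn` over a dict of arrays, computes $J^\top v$ with `jax.vjp`, and checks it against the jvp through the adjoint identity $\langle J\tau, v\rangle = \langle \tau, J^\top v\rangle$.

```python
import jax
import jax.numpy as jnp


def model_fn(weights, x):
    # Hypothetical stand-in for the network f(x; theta).
    return jnp.tanh(x @ weights["w"] + weights["b"])


def tree_vdot(a, b):
    # Inner product of two pytrees with matching structure.
    return sum(jnp.vdot(u, v) for u, v in zip(jax.tree.leaves(a), jax.tree.leaves(b)))


weights_star = {"w": jnp.array([[0.5, -1.0, 2.0]]), "b": jnp.array([0.1, 0.0, -0.2])}
x = jnp.array([[1.0], [2.0]])


def f(w):
    return model_fn(w, x)


tau = {"w": jnp.array([[0.3, 0.2, -0.1]]), "b": jnp.array([-0.4, 0.5, 0.6])}
v = jnp.array([[1.0, -2.0, 0.5], [0.25, 1.5, -1.0]])  # cotangent, same shape as f's output

_, j_tau = jax.jvp(f, (weights_star,), (tau,))  # J tau
_, vjp_fn = jax.vjp(f, weights_star)
(jt_v,) = vjp_fn(v)  # J^T v, a dict with the same keys as weights_star

lhs = jnp.vdot(j_tau, v)
rhs = tree_vdot(tau, jt_v)
```

The jvp pushes a weight-space direction forward to output space, while the vjp pulls an output-space cotangent back to weight space; reverse-mode gradients such as `nnx.grad` are built from the latter.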

Hi all, I'm migrating from dm-haiku to flax.nnx and would definitely like to have nnx.vjp and nnx.jvp. Thanks in advance!

Posted by smao-astro on Jun 25 '25

With the latest Flax, where nnx.Module is a pytree, we may not need nnx.vjp and nnx.jvp at all, since plain JAX transformations work directly. Here is the original code snippet adapted to the latest version:

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


# Original model
class RegressionModel(nnx.Module):

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.linear1 = nnx.Linear(1, 20, rngs=rngs)
        self.linear2 = nnx.Linear(20, 20, rngs=rngs)
        self.linear3 = nnx.Linear(20, 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.tanh(self.linear1(x))
        x = nnx.tanh(self.linear2(x))
        x = self.linear3(x)
        return x


class LinearisedModel(nnx.Module, pytree=False):
    # The pytree argument is new in Flax v0.12.
    # If it were set to True, self.weight_update and self.original_weights would need to be wrapped with nnx.data.

    def __init__(self, original_model: nnx.Module) -> None:
        self.graph_def, original_weights = nnx.split(original_model)
        # original_weights is already an nnx.State with nnx.Param leaves

        self.weight_update = jax.tree.map(jnp.zeros_like, original_weights)  # Want to be trainable, corresponds to "tau" above
        # self.weight_update is an nnx.State with nnx.Param leaves

        self.original_weights = original_weights  # Want to be not trainable

    def __call__(self, x):
        def _model_fn(weights):
            return nnx.call((self.graph_def, weights))(x)[0]

        original_pred, upd = jax.jvp(
            _model_fn,
            (self.original_weights,),  # <--- Nothing to change here
            (self.weight_update,),    # <--- Nothing to change here
        )
        return original_pred + upd

# Alternatively, we can implement LinearisedModel without using nnx.split:
# class LinearisedModel(nnx.Module, pytree=False):
#     # The pytree argument is new in Flax v0.12.
#     # If it were set to True, self.weight_update and self.original_model would need to be wrapped with nnx.data.

#     def __init__(self, original_model: nnx.Module) -> None:
#         self.original_model = original_model  # Want to be not trainable
#         self.weight_update = jax.tree.map(lambda x: 0.1 * jnp.ones_like(x), original_model)  # Want to be trainable, corresponds to "tau" above
#         # self.weight_update is a copy of the RegressionModel whose nnx.Param leaves hold tau (initialised to 0.1 here rather than 0)

#     def __call__(self, x):
#         def _model_fn(model):
#             return model(x)

#         original_pred, upd = jax.jvp(
#             _model_fn,
#             (self.original_model,),
#             (self.weight_update,),
#         )
#         return original_pred + upd

original_model = RegressionModel(rngs=nnx.Rngs(42))
linearised_model = LinearisedModel(original_model)

trainable_params = nnx.All(nnx.Param, nnx.PathContains("weight_update"))
optimizer = nnx.Optimizer(
    linearised_model,
    tx=optax.adamw(3e-4),
    wrt=trainable_params,
)


def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return optax.squared_error(pred, y).mean()

    diff_state = nnx.DiffState(0, trainable_params)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(model, grads)


x = jnp.ones((1, 1))
y = jnp.ones((1, 1))
train_step(linearised_model, optimizer, x, y)
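The reason plain `jax.jvp` works in the version above is that the objects passed to it are now pytrees that JAX can flatten into arrays. The same mechanism can be shown without Flax: the sketch below registers a hypothetical `Dense` class with `jax.tree_util.register_pytree_node_class`. It is a stand-in for the idea, not Flax's actual implementation of `nnx.Module` as a pytree.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Dense:
    # Hypothetical module-like class; registration lets JAX see its arrays as leaves.
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        return (self.w, self.b), None  # (children, static aux data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Dense(jnp.arange(6.0).reshape(3, 2), jnp.zeros((2,)))
x = jnp.ones((2, 3))

# The tangent is another Dense with the same structure; here only w is perturbed.
tangent = Dense(jnp.ones((3, 2)), jnp.zeros((2,)))
out, upd = jax.jvp(lambda m: m(x), (layer,), (tangent,))

# jax.tree.map also returns a Dense, mirroring how weight_update is built with jax.tree.map.
zero_update = jax.tree.map(jnp.zeros_like, layer)
```

Since primal and tangent share one registered structure, `jax.jvp` accepts module-like objects directly, which is why the `nnx.state` step is no longer needed in the updated snippet.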

I think we can close this issue as resolved. @Thomas-Christie, feel free to reopen if you think this topic needs further discussion.

Posted by vfdev-5 on Sep 17 '25