Problem training model with jax.jvp in forward pass
Hi - thanks for the library! I'm currently running into some issues when trying to train a "linearised" NN model. In my setup I take a standard NN with weights $\theta^*$, which I'll denote $f(\cdot;\theta^*)$, and form a "linearised" version of it by taking a first-order Taylor expansion about $\theta^*$. Mathematically:
$f^\text{lin}_{\theta^*}(x;\theta) = f(x;\theta^*) + J_{\theta} f(x;\theta^*) (\theta - \theta^*) $
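For context, jax.jvp evaluates both of these terms in a single call, which is what I'm relying on below. A rough, self-contained sketch with toy stand-ins for $f$, $x$, $\theta^*$ and $\tau$ (not my actual model):

import jax
import jax.numpy as jnp

# Toy stand-ins, purely to illustrate the jax.jvp correspondence
toy_f = lambda x, theta: jnp.tanh(x * theta["w"]) + theta["b"]
toy_x = jnp.array(2.0)
theta_star = {"w": jnp.array(0.5), "b": jnp.array(0.1)}
tau = {"w": jnp.array(0.01), "b": jnp.array(0.0)}

# jax.jvp returns f(x; theta*) and the directional derivative J_theta f(x; theta*) @ tau
primal_out, tangent_out = jax.jvp(lambda theta: toy_f(toy_x, theta), (theta_star,), (tau,))
linearised_out = primal_out + tangent_out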
I then want to train this "linearised" model, keeping the original weights $\theta^*$ fixed and only updating the weights $\theta$ (or, equivalently, the "weight update" $\tau = \theta - \theta^*$). To freeze the weights $\theta^*$ I have tried to adopt the approach mentioned in https://github.com/google/flax/issues/4167. The forward pass through my model works fine, but as soon as I try to train it I run into issues. A minimal example:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax

# Original model
class RegressionModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.linear1 = nnx.Linear(1, 20, rngs=rngs)
        self.linear2 = nnx.Linear(20, 20, rngs=rngs)
        self.linear3 = nnx.Linear(20, 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.tanh(self.linear1(x))
        x = nnx.tanh(self.linear2(x))
        x = self.linear3(x)
        return x

class LinearisedModel(nnx.Module):
    def __init__(self, original_model: nnx.Module) -> None:
        self.graph_def, self.original_weights = nnx.split(original_model)
        self.weight_update = jax.tree.map(
            nnx.Param, jax.tree.map(jnp.zeros_like, self.original_weights)
        )  # Want to be trainable, corresponds to "tau" above
        self.original_weights = jax.tree.map(
            nnx.Param, self.original_weights
        )  # Want to be not trainable

    def __call__(self, x):
        def _model_fn(weights):
            return nnx.call((self.graph_def, weights))(x)[0]

        original_pred, upd = jax.jvp(
            _model_fn,
            (self.original_weights,),
            (self.weight_update,),
        )
        return original_pred + upd

original_model = RegressionModel(rngs=nnx.Rngs(42))
linearised_model = LinearisedModel(original_model)

trainable_params = nnx.All(nnx.Param, nnx.PathContains("weight_update"))
optimizer = nnx.Optimizer(
    linearised_model,
    tx=optax.adamw(3e-4),
    wrt=trainable_params,
)

def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return optax.squared_error(pred, y).mean()

    diff_state = nnx.DiffState(0, trainable_params)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(grads)

x = jnp.ones((1, 1))
y = jnp.ones((1, 1))

train_step(linearised_model, optimizer, x, y)
The error I'm getting is:
TypeError: Argument 'Param(
value=Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.]], dtype=float32)
tangent = Traced<ShapedArray(float32[1,20])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[1,20]), None)
recipe = LambdaBinding()
)' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.
and from the stack trace it seems to stem from the jax.jvp() call.
Can you see what I'm doing wrong? I'm happy to provide more details if necessary. A PyTorch version of what I'm trying to achieve can be found here. Thanks in advance!
Hey! The issue is that jax.jvp doesn't handle nnx.Variable instances (like Param) correctly. To fix it, call nnx.state on them before passing them in; I tested that this runs with your code:
original_pred, upd = jax.jvp(
    _model_fn,
    (nnx.state(self.original_weights),),
    (nnx.state(self.weight_update),),
)
I think we just need to add nnx.jvp to make this cleaner; I've seen a couple of users who need it already.
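In the meantime, a thin wrapper over the workaround above keeps call sites tidy. This is just a sketch of what such a helper might look like (not an existing Flax API, and the name is made up):

def jvp_with_variables(fn, primals, tangents):
    # Hypothetical helper: extract plain States so that jax.jvp
    # only ever sees pytrees of arrays, not Variable instances.
    return jax.jvp(fn, (nnx.state(primals),), (nnx.state(tangents),))

Inside __call__ this would read jvp_with_variables(_model_fn, self.original_weights, self.weight_update).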
Great, thanks, that worked! And yeah, I think having nnx.jvp would probably be a bit more intuitive - I actually searched to see if this existed whilst debugging.
@cgarciae The same holds for vjp, I assume? Is there a reason why these functional transformations don't exist yet? Happy to contribute them if it's as straightforward as it seems.
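For what it's worth, I'd expect the reverse-mode workaround to mirror the jvp fix above (untested sketch, reusing _model_fn from the snippet earlier in the thread):

# Presumably the same nnx.state trick applies to jax.vjp
primal_out, vjp_fn = jax.vjp(_model_fn, nnx.state(self.original_weights))
# output_cotangent is a placeholder with the same shape as the model output
(weight_cotangents,) = vjp_fn(output_cotangent)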
Hi all, I am migrating from dm-haiku to flax.nnx, I definitely want nnx.vjp and nnx.jvp, thank you in advance!
With the latest Flax, where nnx.Module is a pytree, we may no longer need nnx.vjp and nnx.jvp; JAX transformations work directly. Here is working code adapting the original snippet to the latest version:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax

# Original model
class RegressionModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.linear1 = nnx.Linear(1, 20, rngs=rngs)
        self.linear2 = nnx.Linear(20, 20, rngs=rngs)
        self.linear3 = nnx.Linear(20, 1, rngs=rngs)

    def __call__(self, x):
        x = nnx.tanh(self.linear1(x))
        x = nnx.tanh(self.linear2(x))
        x = self.linear3(x)
        return x

class LinearisedModel(nnx.Module, pytree=False):
    # pytree arg is a new thing coming with v0.12
    # if set to True we need to use nnx.data on self.weight_update and self.original_weights
    def __init__(self, original_model: nnx.Module) -> None:
        self.graph_def, original_weights = nnx.split(original_model)
        # original_weights is already a nnx.State with nnx.Param leaves
        self.weight_update = jax.tree.map(jnp.zeros_like, original_weights)  # Want to be trainable, corresponds to "tau" above
        # self.weight_update is a nnx.State with nnx.Param leaves
        self.original_weights = original_weights  # Want to be not trainable

    def __call__(self, x):
        def _model_fn(weights):
            return nnx.call((self.graph_def, weights))(x)[0]

        original_pred, upd = jax.jvp(
            _model_fn,
            (self.original_weights,),  # <--- Nothing to change here
            (self.weight_update,),  # <--- Nothing to change here
        )
        return original_pred + upd

# Alternatively, we can implement LinearisedModel without using nnx.split
# class LinearisedModel(nnx.Module, pytree=False):
#     # pytree arg is a new thing coming with v0.12
#     # if set to True we need to use nnx.data on self.weight_update and self.original_weights
#     def __init__(self, original_model: nnx.Module) -> None:
#         self.original_model = original_model  # Want to be not trainable
#         self.weight_update = jax.tree.map(lambda x: 0.1 * jnp.ones_like(x), original_model)  # Want to be trainable, corresponds to "tau" above
#         # self.weight_update is a nnx.State with nnx.Param leaves
#
#     def __call__(self, x):
#         def _model_fn(model):
#             return model(x)
#
#         original_pred, upd = jax.jvp(
#             _model_fn,
#             (self.original_model,),
#             (self.weight_update,),
#         )
#         return original_pred + upd

original_model = RegressionModel(rngs=nnx.Rngs(42))
linearised_model = LinearisedModel(original_model)

trainable_params = nnx.All(nnx.Param, nnx.PathContains("weight_update"))
optimizer = nnx.Optimizer(
    linearised_model,
    tx=optax.adamw(3e-4),
    wrt=trainable_params,
)

def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return optax.squared_error(pred, y).mean()

    diff_state = nnx.DiffState(0, trainable_params)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(model, grads)

x = jnp.ones((1, 1))
y = jnp.ones((1, 1))

train_step(linearised_model, optimizer, x, y)
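As a quick sanity check, since weight_update is initialised to zeros the first-order term vanishes, so a freshly constructed LinearisedModel should reproduce the original model's predictions before any training:

# With tau = 0 the expansion reduces to f(x; theta*), so fresh
# (untrained) instances of the two models should agree.
check_original = RegressionModel(rngs=nnx.Rngs(42))
check_linearised = LinearisedModel(check_original)
assert jnp.allclose(check_linearised(x), check_original(x), atol=1e-6)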
I think we can close this issue as resolved. @Thomas-Christie feel free to reopen if you think we need further discussion on this topic.