Dropout seems incompatible with jax.jit
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from optax import adam
from tqdm import tqdm

class DNN(nn.Module):
    num_hidden_units1:int
    num_hidden_units2:int
    num_outputs:int
    dropout_rate:float

    @nn.compact
    def __call__(self,x,training):
        x=nn.Dense(self.num_hidden_units1)(x)
        x=nn.relu(x)
        # Dropout is only active when training is True
        x=nn.Dropout(rate=self.dropout_rate,deterministic=not training)(x)
        x=nn.Dense(self.num_hidden_units2)(x)
        x=nn.relu(x)
        x=nn.Dropout(rate=self.dropout_rate,deterministic=not training)(x)
        x=nn.Dense(self.num_outputs)(x)
        return x
def mse(params,X,y,training):
    def squared_error(X,y):
        y_pred=model.apply(params,X,training,rngs={'dropout':jax.random.PRNGKey(114)})
        diff=y_pred-y
        return jnp.inner(diff,diff)
    return jnp.mean(jax.vmap(squared_error)(X,y),axis=0)
@jax.jit
def train(params,opt_state,X,y,training):
    loss,grads=jax.value_and_grad(mse)(params,X,y,training)
    updates,opt_state=optimizer.update(grads,opt_state)
    params=optax.apply_updates(params,updates)
    return params,opt_state,loss
@jax.jit
def fit(params,opt_state,X,y,training):
    for i in tqdm(range(1000)):
        params,opt_state,loss=train(params,opt_state,X,y,training)
        if i%100==0:
            print(loss)
    return params
# Initialization
params=model.init(jax.random.PRNGKey(114),X,False)
learning_rate=0.1
optimizer=adam(learning_rate)
opt_state=optimizer.init(params)
train(params,opt_state,X,y,True)
The model trains successfully with the two dropout layers as long as I don't use jax.jit. However, as soon as I try to speed up training with jax.jit, I get the following error:
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[300], line 1
----> 1 train(params,opt_state,X,y,True)
[... skipping hidden 11 frame]
Cell In[264], line 4, in train(params, opt_state, X, y, training)
1 @jax.jit
2 def train(params,opt_state,X,y,training):
----> 4 loss,grads=jax.value_and_grad(mse)(params,X,y,training)
5 updates,opt_state=optimizer.update(grads,opt_state)
6 params=optax.apply_updates(params,updates)
[... skipping hidden 8 frame]
Cell In[296], line 6, in mse(params, X, y, training)
4 diff=y_pred-y
5 return jnp.inner(diff,diff)
----> 6 return jnp.mean(jax.vmap(squared_error)(X,y),axis=0)
[... skipping hidden 3 frame]
Cell In[296], line 3, in mse.<locals>.squared_error(X, y)
2 def squared_error(X,y):
----> 3 y_pred=model.apply(params,X,training,rngs={'dropout':jax.random.PRNGKey(114)})
4 diff=y_pred-y
5 return jnp.inner(diff,diff)
[... skipping hidden 6 frame]
Cell In[293], line 11, in DNN.__call__(self, x, training)
9 x=nn.Dense(self.num_hidden_units1)(x)
10 x=nn.relu(x)
---> 11 x=nn.Dropout(rate=self.dropout_rate,deterministic=not training)(x)
13 x=nn.Dense(self.num_hidden_units2)(x)
14 x=nn.relu(x)
[... skipping hidden 1 frame]
File /opt/conda/lib/python3.10/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
1491 def error(self, arg):
-> 1492 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function train at /tmp/ipykernel_34/1452843710.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument training.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
It seems something is wrong with the training flag. How can I solve this? Thanks.
You have to mark training as a static argument when you jit your functions, so the compiler knows that you're ok with recompiling the function if its value changes. See: https://jax.readthedocs.io/en/latest/jit-compilation.html#marking-arguments-as-static
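To illustrate the mechanism, here is a minimal sketch in plain JAX (the function `scale` is a made-up toy, not part of the question's code). Under a plain jax.jit, `training` becomes a traced value, so `not training` cannot be turned into a Python bool; marking it static makes it an ordinary Python bool again at trace time.

```python
import jax
import jax.numpy as jnp


def scale(x, training):
    # `not training` needs a concrete Python bool; a traced value cannot provide one.
    if not training:
        return x
    return 2.0 * x


plain_jit = jax.jit(scale)
static_jit = jax.jit(scale, static_argnames=["training"])

x = jnp.ones(3)

try:
    plain_jit(x, True)
    failed = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError is a subclass of this error type.
    failed = True

out_train = static_jit(x, True)   # traced and compiled with training=True
out_eval = static_jit(x, False)   # traced and compiled again with training=False
```

Each distinct value of a static argument triggers a separate compilation, which is cheap for a flag that only takes two values.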
In short, changing your @jax.jit decorators to @partial(jax.jit, static_argnames=['training']) (with from functools import partial) should do the trick. I know it's a bit confusing because the flax dropout guide neglects to mention this.