
`nn.LSTMCell` documentation example raises `AssignSubModuleError`

Open · layssi opened this issue 1 year ago · 0 comments

```python
import flax.linen as nn
import jax, jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (2, 3))
layer = nn.LSTMCell(features=4)
carry = layer.initialize_carry(jax.random.key(1), x.shape)
variables = layer.init(jax.random.key(2), carry, x)
new_carry, out = layer.apply(variables, carry, x)
```

Running this code, which is taken from the documentation, raises the following error:

```
flax.errors.AssignSubModuleError: Submodule LSTMCell must be defined in setup() or in a method wrapped in @compact (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)
```

layssi · Jun 26 '24 15:06