flax
Improve error message when user mistakenly holds a jax Array in an nnx.Module
The line here raises what is, in my opinion, a vague error message (at least it felt that way to me as a beginner to JAX, Flax and NNX).
I was holding a JAX Array (a random key) as an attribute of an nnx.Module, and that triggered the error. To be honest, I am not sure whether this is the only case that triggers it. Either way, I would be happy to help improve the error message.
Hi @RaghuSpaceRajan, feel free to send a PR!
Hi @cgarciae, thanks! I made the error message more verbose in PR #4492. Does it look okay?