AttributeError: module 'flax' has no attribute 'nn'
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.7.3, 0.4.14, 0.4.14
- Python version: Python 3.9.12
- GPU/TPU model and memory: NVIDIA GeForce
- CUDA version (if applicable): 11.7
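The package versions above can also be collected from inside Python instead of running pip show. This is a minimal sketch using the standard-library importlib.metadata module (Python 3.8+). The helper name installed_versions is hypothetical, not part of Flax or pip.

```python
from importlib import metadata
import platform


def installed_versions(packages=("flax", "jax", "jaxlib")) -> dict:
    """Return {package: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            # Reads the installed distribution's metadata, like `pip show`.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("Python", platform.python_version())
    for name, ver in installed_versions().items():
        print(f"{name}: {ver}")
```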
Problem you have encountered:
I'm working with mip-NeRF from Google and tried to train it on a dataset, but got this error:
2023-08-27 13:09:05.370975: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Traceback (most recent call last):
File "/home/alhasan.ali/mip/train.py", line 31, in
The module in Flax is called flax.linen - it's just that people often write import flax.linen as nn for shorthand.
The library you are using was probably written for an old version of Flax, which still had flax.nn. We removed flax.nn and replaced it with flax.linen (the new iteration of our NN abstraction), and the usual convention is import flax.linen as nn, as Ivy said.
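To check which of the two APIs a given environment actually provides, you can ask Python's import system whether flax.linen and the legacy flax.nn can be found, without importing the submodules. This is a minimal sketch. The function name flax_api_status is hypothetical.

```python
import importlib.util


def flax_api_status() -> dict:
    """Report whether Flax is installed and which NN submodules it provides."""
    # For a missing top-level package, find_spec returns None instead of raising.
    if importlib.util.find_spec("flax") is None:
        return {"flax_installed": False}
    # For dotted names, find_spec imports the parent package (flax) first,
    # then returns None if the submodule does not exist.
    return {
        "flax_installed": True,
        "has_linen": importlib.util.find_spec("flax.linen") is not None,
        "has_legacy_nn": importlib.util.find_spec("flax.nn") is not None,
    }


print(flax_api_status())
```

On an install that raises the error above, you would expect has_legacy_nn to be False.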
Looking at the mipnerf repo, its requirements list flax>=0.2.2, but it probably isn't compatible with newer versions. Could you try installing flax==0.2.2 and see whether that works?
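The reason the requirement did not prevent this is that flax>=0.2.2 is only a lower bound, so pip is free to install 0.7.3. The sketch below compares the two specifiers. It uses a hypothetical, simplified parser that handles plain dotted numeric versions only. Real tooling should use a full version parser such as the packaging library.

```python
def parse_version(v: str) -> tuple:
    """Simplified parser: '0.7.3' -> (0, 7, 3). Plain numeric versions only."""
    return tuple(int(part) for part in v.split("."))


def satisfies(installed: str, op: str, required: str) -> bool:
    """Check a single '>=' or '==' specifier against an installed version."""
    a, b = parse_version(installed), parse_version(required)
    # Pad with zeros so that, e.g., (0, 2) compares equal to (0, 2, 0).
    n = max(len(a), len(b))
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    if op == ">=":
        return a >= b
    if op == "==":
        return a == b
    raise ValueError(f"unsupported operator: {op}")


installed = "0.7.3"  # the Flax version from the report above
print(satisfies(installed, ">=", "0.2.2"))  # True: the lower bound accepts 0.7.3
print(satisfies(installed, "==", "0.2.2"))  # False: an exact pin would reject it
```

In practice, the exact pin goes in requirements.txt as flax==0.2.2, or you can install it directly with pip install flax==0.2.2.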