Bug - Cannot access attribute "replace" for class "c123"
System information
- OS Platform and Distribution: WSL 2, Ubuntu 22
- Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib): flax 0.10.0, jax 0.4.30, jaxlib 0.4.29
- Python version: 3.10.12
- GPU/TPU model and memory: CPU, 8 GB RAM
- CUDA version (if applicable): -
Problem you have encountered:
The .replace method of a Flax struct dataclass is flagged as an error by Pylance, as if it does not exist.
What you expected to happen:
The editor (Pylance) shows no errors. The false error is quite annoying, since .replace is called often.
Logs, error messages, etc:
Cannot access attribute "replace" for class "c123"
  Attribute "replace" is unknown  Pylance(reportAttributeAccessIssue)
Steps to reproduce:
from flax.struct import dataclass


@dataclass
class c123:
    variable1: int = 0


c123instance = c123(variable1=0).replace(variable1=1)
print(c123instance)
# Output: c123(variable1=1)
Also, how does .replace work if its implementation uses self? As far as I understand, and according to the JAX docs, using self should raise an error whenever jit is involved.
Hey, this is a known issue. To work around it, inherit from PyTreeNode:
from flax.struct import PyTreeNode


class c123(PyTreeNode):
    variable1: int = 0


c123instance = c123(variable1=0).replace(variable1=1)
And can't you, as devs, make the class inherit from PyTreeNode automatically when it is wrapped? That would make sense. Also, is there no performance penalty with PyTreeNode?
> And can't you, as devs, automatically inherit when class is wrapped?
This is not possible.
> Also, are there no performance drops with pytreenodes?
No performance drop.
Hey, not sure if it helps, but calling dataclasses.replace directly also works:
import dataclasses

from flax import struct


@struct.dataclass
class c123:
    variable1: int = 0


c123instance = dataclasses.replace(c123(variable1=0), variable1=1)