`ValueError` when AOT-lowering or applying remat to functions that call `jax_getattr`.
Description
A simple repro:
import jax.ad_checkpoint
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
from jax.experimental.attrs import jax_getattr, jax_setattr

class A:
    ...

a = A()
jax_setattr(a, 'x', 0)

def fn(c):
    jax_getattr(a, 'x')  # comment out to run successfully
    return c

# ValueError: safe_zip() argument 2 is longer than argument 1
jax.jit(fn).lower(0.0)

# ValueError: too many values to unpack (expected 0)
fn = jax.ad_checkpoint.remat(fn, policy=jax_remat_policies.everything_saveable)
fn(0.0)
System info (python version, jaxlib version, accelerator, etc.)
python: 3.10.14
jax: 0.4.34
jaxlib: 0.4.34
accelerator: cpu