
`ValueError` when attempting to AOT lower or remat functions with `jax_getattr`.

markblee opened this issue 1 year ago.

Description

A simple repro:

import jax
import jax.ad_checkpoint
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
from jax.experimental.attrs import jax_getattr, jax_setattr

class A:
    ...

a = A()
jax_setattr(a, 'x', 0)

def fn(c):
    jax_getattr(a, 'x')  # comment out to run successfully
    return c

# AOT lowering raises: ValueError: safe_zip() argument 2 is longer than argument 1
jax.jit(fn).lower(0.0)

# remat raises: ValueError: too many values to unpack (expected 0)
# (only reached if the jax.jit(fn).lower(0.0) call above is commented out)
fn = jax.ad_checkpoint.remat(fn, policy=jax_remat_policies.everything_saveable)
fn(0.0)
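As written, the script stops at the first `ValueError`, so the `remat` case never runs unless the lowering call is commented out. The sketch below is a variant of the same repro that runs each case independently and records its outcome instead of raising. It assumes the same environment as the system info listed in this issue; `run_case` is a hypothetical helper, not part of JAX. On the reported versions, each entry would be expected to carry the corresponding `ValueError` quoted in the comments above, but that is not verified here.

```python
import jax
from jax.ad_checkpoint import checkpoint_policies
from jax.experimental.attrs import jax_getattr, jax_setattr


class A:
    ...


a = A()
jax_setattr(a, 'x', 0)


def fn(c):
    jax_getattr(a, 'x')  # the attribute read that triggers the reported errors
    return c


def run_case(name, thunk):
    """Hypothetical helper: run one repro case, return an outcome string instead of raising."""
    try:
        thunk()
        return f"{name}: ok"
    except Exception as e:  # broad on purpose: report whatever the JAX version raises
        return f"{name}: {type(e).__name__}: {e}"


results = [
    # AOT lowering of a jitted function that reads the attribute
    run_case("jit(fn).lower", lambda: jax.jit(fn).lower(0.0)),
    # remat with a policy; fn is not reassigned, so this case is independent of the first
    run_case(
        "remat(fn)",
        lambda: jax.ad_checkpoint.remat(
            fn, policy=checkpoint_policies.everything_saveable)(0.0),
    ),
]

for line in results:
    print(line)
```

Keeping the two cases separate makes it easier to confirm that both the AOT path and the `remat` path fail on their own, rather than the second failure being hidden behind the first.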

System info (python version, jaxlib version, accelerator, etc.)

python: 3.10.14
jax: 0.4.34
jaxlib: 0.4.34
accelerator: cpu

markblee, Oct 15 '24 18:10