scan over discrete latent variables causes a tracer leak
Bug Description
This is one of the issues reported in https://github.com/pyro-ppl/numpyro/issues/1981. Running the following test raises an error (or reports an xfail).
Steps to Reproduce
JAX_CHECK_TRACER_LEAKS=1 pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
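The environment variable above turns on JAX's tracer-leak checker for the whole pytest run. When narrowing the leak down inside a single test or notebook, the same check is available as the config flag jax_check_tracer_leaks and as the scoped context manager jax.checking_leaks(). The sketch below is a minimal illustration only; scan_cumsum is a hypothetical leak-free function, not code from numpyro.

```python
import jax
import jax.numpy as jnp


def scan_cumsum(xs):
    # Hypothetical leak-free scan: all per-step state flows through the
    # carry and the stacked outputs, nothing is stored outside the trace.
    def body(carry, x):
        carry = carry + x
        return carry, carry

    return jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)


# Scoped alternative to exporting JAX_CHECK_TRACER_LEAKS=1 for the whole run:
# leak checking is active only inside this block.
with jax.checking_leaks():
    final, history = scan_cumsum(jnp.arange(4.0))

# Global alternative, equivalent to setting the environment variable:
# jax.config.update("jax_check_tracer_leaks", True)
```

The scoped form is handy for bisecting: wrap progressively smaller parts of a failing model until the leak report points at the offending call.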
Expected Behavior
The test should pass.
The cause appears to be this line https://github.com/pyro-ppl/numpyro/blob/b7667062c067b87bb9450205a57227ae872fa91a/numpyro/contrib/funsor/discrete.py#L59, where the stateful adjoint tape is not compatible with JAX's scan.
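A minimal sketch of this failure mode, assuming the leak comes from Python-side state that is mutated inside the scan body. The tape list here is a hypothetical stand-in, not funsor's AdjointTape: jax.lax.scan traces its body function to build the loop, so anything the body appends to an object that outlives the trace is an abstract tracer rather than a concrete per-step value. Returning those values through scan's outputs instead keeps them inside the trace.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a stateful tape: a module-level list that records
# intermediate values as a side effect (this is NOT funsor's AdjointTape).
tape = []


def leaky_cumsum(xs):
    def body(carry, x):
        carry = carry + x
        tape.append(carry)  # side effect: a scan-body tracer escapes the trace
        return carry, None

    final, _ = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return final


def functional_cumsum(xs):
    def body(carry, x):
        carry = carry + x
        return carry, carry  # per-step values leave through scan's outputs

    final, history = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return final, history


xs = jnp.arange(4.0)
print(leaky_cumsum(xs))  # the returned value itself is still correct
try:
    tape[0] + 1.0  # the recorded entry is an escaped tracer, not a number
except jax.errors.UnexpectedTracerError as err:
    print("leaked tracer:", type(err).__name__)

final, history = functional_cumsum(xs)
print(final, history)
```

The leaky version computes the right result, which is why this kind of bug can go unnoticed until the leaked values are used later or leak checking is turned on.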
Switching back to the lazy interpretation seems to fix the leak, but it causes some other tests to fail:
-    with funsor.adjoint.AdjointTape() as tape:
+    with funsor.interpretations.lazy:
         with block(), enum(first_available_dim=first_available_dim):
             log_prob, model_tr, log_measures = _enum_log_density(
                 model, args, kwargs, {}, sum_op, prod_op
             )
     with approx:
-        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
+        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
I was wondering what the status of this is. I'm trying to do SVI with discrete latent variables, and I think the errors I'm getting (tracer leaks that make it impossible to run) come from this bug. I have three questions: (1) Am I right that the fixes in https://github.com/pyro-ppl/numpyro/pull/2002 address this problem, or is the relevant issue actually https://github.com/pyro-ppl/numpyro/issues/1999? (2) Is the fork in a state where it is reasonable to install from and use? (3) Is there anything I could do to help get the PR over the line?