
scan over discrete latent variables causes tracer leak

Open · fehiepsi opened this issue 11 months ago · 2 comments

Bug Description

This is one of the issues reported in https://github.com/pyro-ppl/numpyro/issues/1981. Running the following test raises an error (xfail).

Steps to Reproduce

```shell
JAX_CHECK_TRACER_LEAKS=1 pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
```

Expected Behavior

The test should pass.

fehiepsi, Mar 06 '25 21:03

The leak appears to be caused by this line: https://github.com/pyro-ppl/numpyro/blob/b7667062c067b87bb9450205a57227ae872fa91a/numpyro/contrib/funsor/discrete.py#L59

There, the stateful adjoint tape is not compatible with JAX's scan (jax.lax.scan).
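A minimal sketch of this failure mode, assuming the leak comes from state that outlives the scan trace. The module-level `TAPE` list is a hypothetical stand-in for funsor's `AdjointTape`, not its actual implementation; only `jax.checking_leaks` (a context manager for the same `jax_check_tracer_leaks` setting that `JAX_CHECK_TRACER_LEAKS=1` turns on) and `jax.lax.scan` are real JAX APIs. The stateful body raises under leak checking, while a body that returns its values only through scan's outputs does not.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a stateful tape such as funsor.adjoint.AdjointTape:
# a module-level list that records intermediate values.
TAPE = []


def leaky_step(carry, x):
    y = carry + x
    # Inside scan, `y` is a tracer; storing it in outside state lets it
    # outlive the trace of the scan body, which is what the leak checker flags.
    TAPE.append(y)
    return y, y


def pure_step(carry, x):
    y = carry + x
    # Values leave the body only through the carry and the stacked outputs.
    return y, y


xs = jnp.arange(3.0)
init = jnp.zeros(())  # float32 scalar, matching the dtype of xs

with jax.checking_leaks():
    try:
        jax.lax.scan(leaky_step, init, xs)
        leaked = False
    except Exception as err:  # JAX reports the leaked tracer as an exception
        leaked = True
        print("leak detected:", str(err).splitlines()[0])

TAPE.clear()  # drop the leaked tracers before the next run

with jax.checking_leaks():
    final, ys = jax.lax.scan(pure_step, init, xs)

print(final, ys)  # final == 3.0, ys == [0., 1., 3.]
```

This only illustrates why recording tracers in external state conflicts with scan; it does not show how funsor's adjoint machinery should be restructured.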

Switching back to lazy interpretation seems to fix the leak, but it makes some other tests fail:

```diff
-    with funsor.adjoint.AdjointTape() as tape:
+    with funsor.interpretations.lazy:
         with block(), enum(first_available_dim=first_available_dim):
             log_prob, model_tr, log_measures = _enum_log_density(
                 model, args, kwargs, {}, sum_op, prod_op
             )
 
     with approx:
-        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
+        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
```

fehiepsi, Mar 07 '25 17:03

I was wondering what the status of this is. I'm trying to do SVI with discrete latents, and I think the errors I'm getting (tracer leaks that make it impossible to run) come from this bug. I have three questions: Am I right that the fixes in https://github.com/pyro-ppl/numpyro/pull/2002 address this problem, or is it actually https://github.com/pyro-ppl/numpyro/issues/1999? Is the fork in a state where it's reasonable to install and use? And is there anything I could do to help get the PR over the line?

pscicluna, Jun 10 '25 09:06