Bug in `NeuTraReparam`
Minimal example:
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Trace_ELBO, SVI
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.autoguide import AutoBNAFNormal
# Problem sizes.
n = 100  # number of observations (rows of X and Y)
p = 10  # dimension of each row of X
q = 5  # dimension of each row of Y
k = min(p, q, 3)  # latent dimension: at most 3, and never more than p or q

# Independent standard-normal data: X has shape (n, p), Y has shape (n, q).
X = dist.MultivariateNormal(
    loc=jnp.zeros(p), covariance_matrix=jnp.eye(p)
).sample(PRNGKey(0), sample_shape=(n,))
Y = dist.MultivariateNormal(
    loc=jnp.zeros(q), covariance_matrix=jnp.eye(q)
).sample(PRNGKey(1), sample_shape=(n,))
def model(X, Y=None):
    """Bayesian low-rank regression of ``Y`` on ``X``.

    ``X @ P`` projects each ``p``-dimensional row of ``X`` onto ``k`` latent
    dimensions, and ``@ Q`` maps the result to the ``q``-dimensional response.
    The rows of ``P`` and ``Q`` get zero-mean multivariate normal priors whose
    diagonal variances have ``InverseGamma(3, 1)`` priors.

    Args:
        X: design matrix; assumed to be (n, p) like the module-level ``X``.
        Y: observed responses, presumably (n, q). When ``None`` the ``'Y'``
            site is not conditioned on data.

    Returns:
        The value of the ``'Y'`` sample site.
    """
    # NOTE(review): under NeuTraReparam these plated latent sites reportedly
    # trigger "mul got incompatible shapes for broadcasting"; a suggested
    # workaround is to drop the plate and write e.g.
    # numpyro.sample('P_cov', dist.InverseGamma(3, 1).expand([k]).to_event()).
    with numpyro.plate('_k', k):
        P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1))
    with numpyro.plate('_q', q):
        Q_cov = numpyro.sample('Q_cov', dist.InverseGamma(3, 1))
    # Turn the per-dimension variances into diagonal covariance matrices.
    P_cov = P_cov * jnp.eye(k, k)
    Q_cov = Q_cov * jnp.eye(q, q)
    # P stacks p rows of length k; Q stacks k rows of length q.
    with numpyro.plate('p', p):
        P = numpyro.sample('P', dist.MultivariateNormal(jnp.zeros(k), P_cov))
    with numpyro.plate('k', k):
        Q = numpyro.sample('Q', dist.MultivariateNormal(jnp.zeros(q), Q_cov))
    with numpyro.plate('n', n):
        Z = X @ P  # low rank representation of X
        Y_pred = Z @ Q  # transform back into Y via Q
        return numpyro.sample('Y', dist.MultivariateNormal(Y_pred, jnp.eye(q, q)), obs=Y)
# --- this works: plain NUTS on the original model ---
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50)
mcmc.run(PRNGKey(2), X, Y)

# --- this fails: NUTS on the NeuTra-reparameterized model ---
# Step 1: fit an AutoBNAFNormal guide with SVI.
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8, 8])
optimizer = numpyro.optim.Adam(step_size=0.003)
svi = SVI(model, guide, optimizer, Trace_ELBO())
svi_result = svi.run(PRNGKey(3), 5_000, X, Y)

# Step 2: reparameterize the model through the fitted guide and run NUTS.
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(model)
mcmc = MCMC(NUTS(neutra_model), num_warmup=1_000, num_samples=3_000)
mcmc.run(PRNGKey(4), X, Y)
I'm not entirely sure what's going on here. The model above works with vanilla NUTS, but raises `TypeError: mul got incompatible shapes for broadcasting: (3, 5), (5, 5)` when running NUTS after reparameterizing it with `NeuTraReparam`.
If I remove the first two plates and replace the `P_cov` and `Q_cov` latents with the constants
P_cov = jnp.eye(k, k)
Q_cov = jnp.eye(q, q)
the code runs but I get the following warnings:
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_P_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_Q_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site 'Y'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
mcmc.run(jax.random.PRNGKey(4), X, Y)
Maybe it has something to do with having multiple differently named plates of the same size (e.g. `_k` and `k`)?
Thanks @amifalk! This is a bug: we allow `plate` to be applied to the unconstrained value: https://github.com/pyro-ppl/numpyro/blob/b16741cc163b1a3753a331e3200c64cced9eb804/numpyro/infer/reparam.py#L283-L286
A temporary workaround is to remove the plate around the first site and give it an explicit event shape with `.expand([k]).to_event()` instead:
P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1).expand([k]).to_event())