bug in NeuTraReparam

Open amifalk opened this issue 2 years ago • 1 comments

Minimal example:

import jax
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Trace_ELBO, SVI
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.autoguide import AutoBNAFNormal

n = 100
p = 10 # n_dim x
q = 5 # n_dim y
k = min(3, p, q) # n_dim latent

X = dist.MultivariateNormal(jnp.zeros(p), jnp.eye(p, p)).sample(PRNGKey(0), (n,))
Y = dist.MultivariateNormal(jnp.zeros(q), jnp.eye(q, q)).sample(PRNGKey(1), (n,))

def model(X, Y=None):
    with numpyro.plate('_k', k):
        P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1))

    with numpyro.plate('_q', q):
        Q_cov = numpyro.sample('Q_cov', dist.InverseGamma(3, 1))

    P_cov = P_cov * jnp.eye(k, k)
    Q_cov = Q_cov * jnp.eye(q, q)

    with numpyro.plate('p', p):
        P = numpyro.sample('P', dist.MultivariateNormal(jnp.zeros(k), P_cov))

    with numpyro.plate('k', k):
        Q = numpyro.sample('Q', dist.MultivariateNormal(jnp.zeros(q), Q_cov))

    with numpyro.plate('n', n):
        Z = X @ P  # low rank representation of X
        Y_pred = Z @ Q  # transform back into Y via Q

        return numpyro.sample('Y', dist.MultivariateNormal(Y_pred, jnp.eye(q, q)), obs=Y)

# --- this works ---
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50)
mcmc.run(jax.random.PRNGKey(2), X, Y)

# --- this fails ---
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8, 8])
svi = SVI(model, guide, numpyro.optim.Adam(0.003), Trace_ELBO())

svi_result = svi.run(jax.random.PRNGKey(3), 5_000, X, Y)
neutra = NeuTraReparam(guide, svi_result.params)

mcmc = MCMC(NUTS(neutra.reparam(model)), num_warmup=1_000, num_samples=3_000)
mcmc.run(jax.random.PRNGKey(4), X, Y)

I'm not entirely sure what's going on here. The model above works with vanilla NUTS, but raises TypeError: mul got incompatible shapes for broadcasting: (3, 5), (5, 5) when I try to run NUTS after reparameterizing with NeuTraReparam.
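
For reference, the message itself is JAX's ordinary broadcasting error for elementwise multiplication. The sketch below reproduces it in isolation with arrays of the two reported shapes. It only illustrates what the shapes in the message mean and does not claim where inside the reparameterized model the (3, 5) array comes from; the names lhs and rhs are made up for illustration.

```python
import jax.numpy as jnp

k, q = 3, 5
lhs = jnp.ones((k, q))  # hypothetical array with an unexpected leading dimension
rhs = jnp.eye(q, q)     # same shape as jnp.eye(q, q) in the model

error_message = None
try:
    lhs * rhs  # (3, 5) and (5, 5) disagree on the leading axis, so broadcasting fails
except (TypeError, ValueError) as err:
    error_message = str(err)

print(error_message)
```

By contrast, an array of shape (5,) multiplies cleanly with jnp.eye(5, 5), which is the shape Q_cov has when the model runs under plain NUTS.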

If I remove the top two plates and replace the latents with the constants

P_cov = jnp.eye(k, k)
Q_cov = jnp.eye(q, q)

the code runs but I get the following warnings:

<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_P_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site '_Q_log_prob'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)
<ipython-input-17-7e6df362d6ff>:54: UserWarning: Missing a plate statement for batch dimension -2 at site 'Y'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(jax.random.PRNGKey(4), X, Y)

Maybe it has something to do with having multiple plate names with the same dimension?

amifalk, Dec 07 '23 14:12

Thanks @amifalk! This is a bug because we allow plate to be applied to the unconstrained value: https://github.com/pyro-ppl/numpyro/blob/b16741cc163b1a3753a331e3200c64cced9eb804/numpyro/infer/reparam.py#L283-L286

A temporary workaround is to remove the plate for the first site and declare its dimension as an event dimension instead:

P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1).expand([k]).to_event())

fehiepsi, Dec 07 '23 18:12