`jit_model_args` does not seem to work with `chain_method="parallel"`

See the discrepancy below in per-run durations between sequential and parallel chains: with `jit_model_args=True`, sequential runs speed up after the first call, while parallel runs stay at first-call speed.
```python
import time

import numpy as np
import numpyro as ny
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random
import jax
import pandas as pd
import arviz as az

az.style.use("arviz-darkgrid")
az.rcParams["plot.max_subplots"] = 1000
# pd.options.display.max_rows = 1000
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 2000)

ny.util.set_host_device_count(6)


def model(data):
    width = data.shape[1]
    length = data.shape[0]
    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfCauchy(jnp.ones(width)))
    chol = ny.deterministic(
        "chol",
        jnp.linalg.cholesky(
            jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))
        ),
    )
    mu_reparam = ny.sample(
        "mu_reparam",
        dist.MultivariateNormal(
            loc=jnp.zeros(width), covariance_matrix=jnp.identity(width)
        ),
    )
    mu = ny.deterministic("mu", jnp.matmul(chol, mu_reparam))
    with ny.plate("data", length):
        ny.sample("returns", dist.MultivariateNormal(loc=mu, scale_tril=chol), obs=data)


data = np.random.random((1000, 5))

mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
    progress_bar=False,
    jit_model_args=True,
    chain_method="sequential",
)
for _ in range(5):
    t1 = time.time()
    mcmc.run(random.PRNGKey(0), data=data, extra_fields=["diverging", "potential_energy"])
    t2 = time.time()
    print(t2 - t1)

print("====================")

mcmc = MCMC(
    NUTS(model),
    num_warmup=500,
    num_samples=500,
    num_chains=4,
    progress_bar=False,
    jit_model_args=True,
    chain_method="parallel",
)
for _ in range(5):
    t1 = time.time()
    mcmc.run(random.PRNGKey(0), data=data, extra_fields=["diverging", "potential_energy"])
    t2 = time.time()
    print(t2 - t1)
```
Output:

```
20.672340869903564
8.105369567871094
7.848393440246582
8.228676319122314
8.964311599731445
====================
20.858465909957886
19.117286682128906
18.85404682159424
19.331879377365112
19.406771898269653
```
I think that currently, this is not supported. We should mention this in the docs. I guess you can try the pattern

```python
def f(...):
    mcmc = ...
    mcmc.run(...)
    return ...

jit_run = jax.jit(f)
samples1 = jit_run(...)
samples2 = jit_run(...)
```

to see if you can jit a parallel map and get any speedup.
This approach works for my use case. Thanks! Feel free to close.
I guess we can leave the issue open for a while. As I understand it, we can jit+pmap here by passing the model args into the `map_args` tuple and specifying `in_axes` correctly. However, it seems that jit+pmap is not efficient (judging from the warning emitted when running it). There might be some workaround that changes the implementation of `MCMC.run` a bit; I'm not sure. I feel that if we do something like

```python
partial_map_fn = partial(
    self._single_chain_mcmc,
    # args=args,
    kwargs=kwargs,
    collect_fields=collect_fields,
)
states, last_state = pmap(partial_map_fn, in_axes=(0, None))(map_args, args)
```

then pmap will be fast on the second run.
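As a standalone illustration of the `partial` + `in_axes=(0, None)` idea (a toy stand-in for `_single_chain_mcmc`, not NumPyro's actual implementation): static kwargs are closed over with `functools.partial`, the per-chain PRNG keys are mapped over axis 0, and the shared model args are broadcast to every device with `None`:

```python
from functools import partial

import jax
import jax.numpy as jnp

def single_chain(rng_key, data, num_steps):
    # toy stand-in for a per-chain kernel: static kwargs are baked in via
    # partial, so pmap only sees the array arguments
    noise = jax.random.normal(rng_key, (num_steps,))
    return jnp.mean(data) + noise

partial_fn = partial(single_chain, num_steps=3)
n_dev = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)  # one key per device/chain
data = jnp.arange(4.0)
# in_axes=(0, None): map over the per-chain keys, broadcast the shared data
out = jax.pmap(partial_fn, in_axes=(0, None))(keys, data)
print(out.shape)  # (n_dev, 3)
```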