
`jit_model_args` seems to not work with `chain_method="parallel"`

colehaus opened this issue 4 years ago • 3 comments

See the discrepancy below in the iteration durations between sequential and parallel chains.

import numpy as np

import numpyro as ny
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random
import jax

import pandas as pd
import arviz as az

az.style.use("arviz-darkgrid")

az.rcParams["plot.max_subplots"] = 1000

# pd.options.display.max_rows = 1000
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 2000)

ny.util.set_host_device_count(6)
def model(data):
    width = data.shape[1]
    length = data.shape[0]

    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfCauchy(jnp.ones(width)))
    chol = ny.deterministic("chol", jnp.linalg.cholesky(jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))))
    
    mu_reparam = ny.sample("mu_reparam", dist.MultivariateNormal(loc=jnp.zeros(width), covariance_matrix=jnp.identity(width)))
    mu = ny.deterministic("mu", jnp.matmul(chol,  mu_reparam))
    
    with ny.plate("data", length):
        returns = ny.sample("returns", dist.MultivariateNormal(loc=mu, scale_tril=chol), obs=data)
import time

data = np.random.random((1000,5))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4, progress_bar=False, jit_model_args=True, chain_method="sequential")

for _ in range(5):
    t1 = time.time()
    mcmc.run(random.PRNGKey(0), data=data, extra_fields=["diverging", "potential_energy"])
    t2 = time.time()
    print(t2 - t1)
    
print("====================")
    
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=4, progress_bar=False, jit_model_args=True, chain_method="parallel")    
for _ in range(5):
    t1 = time.time()
    mcmc.run(random.PRNGKey(0), data=data, extra_fields=["diverging", "potential_energy"])
    t2 = time.time()
    print(t2 - t1)
20.672340869903564
8.105369567871094
7.848393440246582
8.228676319122314
8.964311599731445
====================
20.858465909957886
19.117286682128906
18.85404682159424
19.331879377365112
19.406771898269653

colehaus avatar Aug 25 '21 01:08 colehaus

I think this is not currently supported; we should mention it in the docs. You could try the pattern

def f(...):
    mcmc = ...
    mcmc.run(...)
    return ...

jit_run = jax.jit(f)
samples1 = jit_run(...)
samples2 = jit_run(...)

to see whether you can jit a parallel map and get any speedup.

fehiepsi avatar Aug 25 '21 02:08 fehiepsi

This approach works for my use case. Thanks! Feel free to close.

colehaus avatar Aug 26 '21 03:08 colehaus

I guess we can leave the issue open for a while. As I understand it, we can combine jit and pmap here by passing the model args through the map_args tuple and specifying in_axes correctly. However, using jit+pmap does not seem efficient (judging from the warning raised when running it). There might be a workaround that changes the implementation of MCMC.run a bit; I'm not sure. I feel that if we do something like

        partial_map_fn = partial(
            self._single_chain_mcmc,
            # args=args,
            kwargs=kwargs,
            collect_fields=collect_fields,
        )
        states, last_state = pmap(partial_map_fn, in_axes=(0, None))(map_args, args)

then pmap will be fast in the second run.
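As a rough illustration of the in_axes idea (a toy stand-in, not numpyro's actual _single_chain_mcmc), pmap can map the per-chain state along axis 0 while passing the shared model args through unmapped with in_axes=(0, None):

```python
# Hypothetical sketch: map over per-chain rng keys (axis 0) while
# broadcasting the shared data argument to every device (in_axes=None).
import jax
import jax.numpy as jnp
from jax import random

def single_chain(rng_key, data):
    # stand-in for a single-chain MCMC step: one noisy estimate of the mean
    noise = random.normal(rng_key, data.shape)
    return jnp.mean(data + noise)

keys = random.split(random.PRNGKey(0), jax.local_device_count())
data = jnp.ones(5)
# keys are mapped per device; data is shared without being re-sliced
per_chain = jax.pmap(single_chain, in_axes=(0, None))(keys, data)
```

Passing the args as an unmapped axis means they are donated to the pmapped computation directly rather than closed over, which is what would let recompilation be avoided on subsequent runs.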

fehiepsi avatar Aug 29 '21 03:08 fehiepsi