Fitting an ODE model with `diffeqsolve` is extremely slow using NUTS on GPU

Open · kokbent opened this issue 2 years ago • 4 comments

As the title says, I've been trying to fit my SIR ODE model using NUTS on a GPU. However, the fit is extremely slow compared to running on the CPU. I'm using JAX and NumPyro for the fitting. I ran this on Google Colab:

CPU:
sample: 100%|██████████| 2000/2000 [02:15<00:00, 14.78it/s, 7 steps of size 3.16e-01. acc. prob=0.94]

GPU (I had to interrupt it because it was too slow):
warmup: 1%| | 16/2000 [05:11<10:43:38, 19.47s/it, 1 steps of size 2.14e-04. acc. prob=0.58]

This is not specific to diffrax: I had the same problem using `odeint` as my ODE solver. Searching online, it seems a similar issue (with `odeint`) was reported in JAX: "Gradients with odeint slow on GPU" #5006. According to one of the replies, the tight loop structure in `odeint` is not friendly to XLA on GPU. Since I see the same behaviour with `diffeqsolve`, I guess it uses a similar technique and suffers from the same issue? So the question is: is there any way to work around the problem within diffrax, perhaps with a different implementation?


Here's the code I use:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5

numpyro.set_platform("cpu")  # set to "gpu" for the GPU run


def sir_ode(state, _, parameters):
    # Unpack state
    s, i, r = state
    beta, gamma = parameters
    population = s + i + r

    # Compute flows
    ds_to_i = beta * s * i / population
    di_to_r = gamma * i

    # Compute derivatives
    ds = -ds_to_i
    di = ds_to_i - di_to_r
    dr = di_to_r

    return (ds, di, dr)  # jnp.stack([ds, di, dr])


# Parameters
rng = np.random.default_rng(seed=867530)
beta = 1.5 / 4.5
gamma = 1.0 / 4.5
population = 10000
initial_infections = 1.0

initial_state = (
    population - initial_infections,  # s
    initial_infections,  # i
    0.0,  # r (a float, so every state component has a floating dtype)
)

# Solve ODE
term = ODETerm(lambda t, state, parameters: sir_ode(state, t, parameters))
solver = Tsit5()
t0 = 0.0
t1 = 100.0
dt0 = 0.1
times = jnp.linspace(t0, t1, 101)
saveat = SaveAt(ts=times)



def des(initial_state, args):
    solution = diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        initial_state,
        args=args,
        saveat=saveat,
    )
    return solution


sol = des(initial_state, [beta, gamma])

# Generate incidence sample
rng = np.random.default_rng(seed=8675309)
incidence = -np.diff(sol.ys[0], axis=0)
incidence_sample = rng.poisson(incidence)


# Sampling model
def sir(times, incidence):
    # Parameters
    initial_infections = numpyro.sample("initial_infections", dist.Exponential(1.0))
    beta = numpyro.sample("beta", dist.Exponential(1.0))
    gamma = numpyro.sample("gamma", dist.Exponential(1.0))

    initial_state = (
        population - initial_infections,  # s
        initial_infections,  # i
        0.0,  # r
    )

    # Integrate the model
    solution = des(initial_state, [beta, gamma])
    model_incidence = -jnp.diff(solution.ys[0], axis=0)

    # Observed incidence
    numpyro.sample("incidence", dist.Poisson(model_incidence), obs=incidence)


rng_key = random.PRNGKey(8811)
nuts_kernel = NUTS(sir, dense_mass=True)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, times, incidence_sample)

kokbent, Dec 08 '23 19:12

The first thing that jumps out is that you don't appear to be explicitly JIT'ing your computation. Diffrax already does this for you internally for the most part, but even so, best practice is to wrap des in equinox.filter_jit.

The second is that beta and gamma look like they might be Python floats rather than JAX arrays, in which case I suspect things are recompiling every time. Make them NumPy or JAX arrays. (When using equinox.filter_jit, the rule is that things recompile if a JAX/NumPy array changes shape or dtype, or if anything else changes in any way at all.)

patrick-kidger, Dec 08 '23 19:12

Hi Patrick, thanks for the response. I've jitted my des function as you suggested. As for beta and gamma, making them JAX arrays in the first part of the code doesn't seem to have much effect (there they are only used to generate a random sample). Within the sampling model sir(), they are handled by NumPyro, and I believe all the sampled parameters are already JAX-traceable arrays. The MCMC is still very slow. I should probably raise the issue with NumPyro as well.

kokbent, Dec 11 '23 16:12

You can double-check whether recompilation is happening with equinox.debug.assert_max_traces, by the way.

patrick-kidger, Dec 11 '23 16:12