dynamax
Marginal log likelihood of the LGSSM parallel filter does not match the recursive filter
The marginal log likelihoods returned by the recursive filter and the parallel filter differ. The example below makes a small modification to the test case in `parallel_inference_test.py`: it sets `μ0[0] = 100`.
Reproduce
from jax import numpy as jnp
from jax import random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_smoother, lgssm_filter, parallel_lgssm_filter, lgssm_joint_sample
def make_static_lgssm_params():
    """Create a static LGSSM with fixed parameters.

    The first two latent coordinates are advanced by ``dt`` times the last
    two coordinates (position/velocity-style dynamics). Only the first two
    latent coordinates are observed, with isotropic Gaussian noise.

    Returns:
        A ``(params, lgssm)`` tuple: the parameters returned by
        ``LinearGaussianSSM.initialize`` and the model instance itself.
    """
    dt = 0.1
    # Identity plus dt on the second superdiagonal, i.e.
    # z[i] <- z[i] + dt * z[i + 2] for i in {0, 1}.
    F = jnp.eye(4) + dt * jnp.eye(4, k=2)
    # Process noise: a 2x2 block over (coordinate, its rate), replicated for
    # both observed axes via a Kronecker product with I_2.
    Q = 1. * jnp.kron(
        jnp.array([[dt**3 / 3, dt**2 / 2],
                   [dt**2 / 2, dt]]),
        jnp.eye(2),
    )
    # Observe only the first two latent coordinates.
    H = jnp.eye(2, 4)
    R = 0.5 ** 2 * jnp.eye(2)
    # The large first component (100) is the modification described in the
    # issue text; it is what exposes the filter mismatch.
    μ0 = jnp.array([100., 0., 1., -1.])
    Σ0 = jnp.eye(4)

    # Derive the dimensions from the emission matrix so they cannot drift
    # out of sync with F, Q, H and R (gives latent_dim=4, observation_dim=2).
    observation_dim, latent_dim = H.shape
    lgssm = LinearGaussianSSM(latent_dim, observation_dim)
    params, _ = lgssm.initialize(
        jr.PRNGKey(0),
        initial_mean=μ0,
        initial_covariance=Σ0,
        dynamics_weights=F,
        dynamics_covariance=Q,
        emission_weights=H,
        emission_covariance=R,
    )
    return params, lgssm
# Build the model and draw one synthetic emission sequence from it.
params, lgssm = make_static_lgssm_params()
num_timesteps = 50
key = jr.PRNGKey(1)
_, emissions = lgssm_joint_sample(params, key, num_timesteps)

# Run the parallel and the recursive filter on the same emissions; their
# marginal log likelihoods are expected to agree.
parallel_posterior = parallel_lgssm_filter(params, emissions)
posterior = lgssm_filter(params, emissions)

# One value per line: the recursive filter first, then the parallel one.
print(*(post.marginal_loglik for post in (posterior, parallel_posterior)), sep="\n")
Output (first line: `lgssm_filter`; second line: `parallel_lgssm_filter`):
-106.0021
-4143.4688
I have not worked through the math in detail, but I think the discrepancy comes from this line: https://github.com/probml/dynamax/blob/800fae691edc7a372605a230d91344bd4420fd93/dynamax/linear_gaussian_ssm/parallel_inference.py#L186
I suspect it should be replaced with something like:
# Proposed fix: score the observation y under the one-step predictive
# Gaussian N(H @ m, H @ P @ H.T + R), i.e. with mean H @ m rather than a
# mean that ignores m.
# NOTE(review): m and P presumably denote the initial mean and covariance
# used for the first filtering element in parallel_inference.py — confirm
# against that file before applying.
logZ = -MVN(H @ m, H @ P @ H.T + R).log_prob(y)