Marginal log likelihood of LGSSM parallel filter does not match recursive filter

Status: Open · gorold opened this issue 7 months ago · 0 comments

The marginal log likelihoods computed by the recursive filter (lgssm_filter) and the parallel filter (parallel_lgssm_filter) differ. The example below is a small modification of the test case in parallel_inference_test.py: it sets μ0[0] = 100.
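For reference, the marginal log likelihood of an LGSSM factors into one-step predictive densities, log p(y_1:T) = Σ_t log N(y_t; H m_t|t-1, H P_t|t-1 Hᵀ + R), where m_t|t-1 and P_t|t-1 are the predicted latent moments, so both filters should return the same number up to floating-point error. The sketch below computes this sum with a plain recursive Kalman filter in NumPy, independent of dynamax. It assumes (μ0, Σ0) is the prior on the first latent state, so no dynamics step precedes the first observation (my reading of dynamax's convention); the function names are hypothetical.

```python
import numpy as np


def mvn_logpdf(y, mean, cov):
    """Log-density of a multivariate normal, computed via a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, y - mean)  # whitened residual
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (len(y) * np.log(2.0 * np.pi) + log_det + z @ z)


def kalman_marginal_loglik(m0, P0, F, Q, H, R, ys):
    """Hypothetical helper: sum of one-step predictive log-densities log p(y_t | y_1:t-1).

    Assumes (m0, P0) is the prior on the *first* latent state, so the
    dynamics are applied only between observations.
    """
    m, P = m0, P0
    total = 0.0
    for t, y in enumerate(ys):
        if t > 0:
            # Predict: push the previous filtered moments through the dynamics.
            m = F @ m
            P = F @ P @ F.T + Q
        # Predictive distribution of y_t is N(H m, H P H^T + R); note the mean H m.
        S = H @ P @ H.T + R
        total += mvn_logpdf(y, H @ m, S)
        # Update: condition the latent state on y_t (K = P H^T S^-1).
        K = np.linalg.solve(S, H @ P).T
        m = m + K @ (y - H @ m)
        P = P - K @ S @ K.T
    return total
```

Passing the arrays from make_static_lgssm_params (converted to NumPy) and the sampled emissions should reproduce the recursive filter's value, which gives an independent reference for judging which of the two outputs is correct.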

Reproduce

from jax import numpy as jnp
from jax import random as jr

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_smoother, lgssm_filter, parallel_lgssm_filter, lgssm_joint_sample


def make_static_lgssm_params():
    """Create a static LGSSM with fixed parameters."""
    dt = 0.1
    F = jnp.eye(4) + dt * jnp.eye(4, k=2)
    Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
                                 [dt**2/2, dt]]),
                      jnp.eye(2))

    H = jnp.eye(2, 4)
    R = 0.5 ** 2 * jnp.eye(2)
    μ0 = jnp.array([100., 0., 1., -1.])  # modified from the test case: μ0[0] = 100
    Σ0 = jnp.eye(4)

    latent_dim = 4
    observation_dim = 2

    lgssm = LinearGaussianSSM(latent_dim, observation_dim)
    params, _ = lgssm.initialize(jr.PRNGKey(0),
                                 initial_mean=μ0,
                                 initial_covariance=Σ0,
                                 dynamics_weights=F,
                                 dynamics_covariance=Q,
                                 emission_weights=H,
                                 emission_covariance=R)
    return params, lgssm


num_timesteps = 50
key = jr.PRNGKey(1)
params, lgssm = make_static_lgssm_params()
_, emissions = lgssm_joint_sample(params, key, num_timesteps)  # sample one trajectory
posterior = lgssm_filter(params, emissions)                     # sequential (recursive) Kalman filter
parallel_posterior = parallel_lgssm_filter(params, emissions)   # associative-scan (parallel) filter


print(posterior.marginal_loglik)
print(parallel_posterior.marginal_loglik)

Output

-106.0021
-4143.4688

I haven't worked through the math in detail, but I suspect the discrepancy comes from this line: https://github.com/probml/dynamax/blob/800fae691edc7a372605a230d91344bd4420fd93/dynamax/linear_gaussian_ssm/parallel_inference.py#L186
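A rough consistency check on the numbers above, assuming (not verified against the source) that the linked line scores the first observation y_1 under a zero-mean Gaussian instead of the predictive N(Hμ0, HΣ0Hᵀ + R). With the reproduction's parameters, HΣ0Hᵀ = I_2 and R = 0.25 I_2. Write e = y_1 - Hμ0 and let e_1 be its first coordinate:

$$
S_1 = H\Sigma_0H^\top + R = 1.25\,I_2, \qquad H\mu_0 = (100,\ 0)^\top
$$

$$
\log\mathcal{N}(y_1;\,0,\,S_1) - \log\mathcal{N}(y_1;\,H\mu_0,\,S_1)
= -\frac{\lVert y_1\rVert^2 - \lVert e\rVert^2}{2\cdot 1.25}
= -\frac{\lVert H\mu_0\rVert^2 + 2\,(H\mu_0)^\top e}{2.5}
= -4000 - 80\,e_1
$$

Since e_1 ~ N(0, 1.25) (standard deviation ≈ 1.12), a zero-mean first term would lower the total by roughly 4000. The observed gap, -4143.4688 - (-106.0021) = -4037.4667, corresponds to e_1 ≈ 0.47, which fits this explanation provided the remaining time steps agree between the two filters.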

I'm guessing it should be replaced with something like:

logZ = -MVN(H @ m, H @ P @ H.T + R).log_prob(y)
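To make the suggested replacement concrete, here is a small NumPy sketch of the first step's normalizer under both choices of mean. It assumes, as suspected above, that the current code uses a zero mean; the helper names are hypothetical, and the sign convention (logZ as a negative log density) follows the snippet above. Placing y_1 exactly at the predicted mean isolates the effect of the mean term.

```python
import numpy as np


def gaussian_logpdf(y, mean, cov):
    """log N(y; mean, cov) for a small dense covariance matrix."""
    d = y - mean
    _, log_det = np.linalg.slogdet(cov)
    quad = d @ np.linalg.solve(cov, d)
    return -0.5 * (len(y) * np.log(2.0 * np.pi) + log_det + quad)


def first_step_logZ(y1, m0, P0, H, R, use_predicted_mean=True):
    """Hypothetical helper: negative log normalizer of the first filtering element.

    use_predicted_mean=True  -> -log N(y1; H m0, H P0 H^T + R)  (proposed fix)
    use_predicted_mean=False -> -log N(y1; 0,    H P0 H^T + R)  (suspected bug)
    """
    S = H @ P0 @ H.T + R
    mean = H @ m0 if use_predicted_mean else np.zeros_like(y1)
    return -gaussian_logpdf(y1, mean, S)


# Parameters from the reproduction above (only the first step matters here).
H = np.eye(2, 4)
R = 0.5 ** 2 * np.eye(2)
m0 = np.array([100.0, 0.0, 1.0, -1.0])
P0 = np.eye(4)
y1 = H @ m0  # an observation exactly at the predicted mean

fixed = first_step_logZ(y1, m0, P0, H, R, use_predicted_mean=True)
buggy = first_step_logZ(y1, m0, P0, H, R, use_predicted_mean=False)
print(buggy - fixed)  # roughly 4000, i.e. ||H m0||^2 / (2 * 1.25)
```

The gap grows with ‖Hμ0‖², which would explain why the default test case (with a small initial mean) does not catch the problem, while setting μ0[0] = 100 exposes it.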

gorold, Jul 10 '25 07:07