
Use Cholesky decomposition instead of `jnp.linalg.inv` in Multivariate Gaussian Distributions

Open • daniel-dodd opened this issue 2 years ago • 1 comment

Thanks for creating such a nice library.

I have experienced numerical instability in the multivariate Gaussian distributions. For better numerical stability, I wondered whether it would be better to replace, e.g., https://github.com/NeilGirdhar/efax/blob/e6533df2d0aa9e775665ef8c5457d29277ef8f40/efax/_src/distributions/multivariate_normal/arbitrary.py#L48

with Cholesky-based operations, e.g.:

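import jax.numpy as jnp
import jax.scipy as jsp

JITTER = 1e-6  # small diagonal jitter so the factorization stays positive definite
# negative_half_precision stores -Σ⁻¹/2, so negate it to get a positive-definite matrix.
half_precision = -self.negative_half_precision
eye = jnp.eye(half_precision.shape[0])
# jsp.linalg.cholesky returns the upper factor by default; request the lower factor L.
lower_h = jsp.linalg.cholesky(half_precision + JITTER * eye, lower=True)
lower_h_inv = jsp.linalg.solve_triangular(lower_h, eye, lower=True)  # L⁻¹
# (L Lᵀ)⁻¹ = L⁻ᵀ L⁻¹: solve Lᵀ X = L⁻¹, then negate to get inv(negative_half_precision).
h_inv = -jsp.linalg.solve_triangular(lower_h, lower_h_inv, lower=True, trans=1)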

We can also compute the log-determinant as `ld = 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower_h)))`. The downside of using the Cholesky decomposition in JAX, though, is that it really wants you to work in float64.
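As a self-contained sketch of this proposal, the helpers below factor the positive-definite matrix once and reuse that factor for both the inverse and the log-determinant. The helper names and the default jitter are hypothetical (they are not part of efax); the sketch assumes a single unbatched matrix and that `negative_half_precision` holds -Σ⁻¹/2. Float64 is switched on with JAX's `jax_enable_x64` config flag, which must happen before arrays are created.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

# JAX defaults to float32; enable float64 before creating any arrays.
jax.config.update("jax_enable_x64", True)

JITTER = 1e-6  # hypothetical default jitter


def cholesky_of_half_precision(negative_half_precision, jitter=JITTER):
    """Lower Cholesky factor L with L @ L.T == -negative_half_precision + jitter * I."""
    half_precision = -negative_half_precision  # positive definite
    eye = jnp.eye(half_precision.shape[0], dtype=half_precision.dtype)
    return jsl.cholesky(half_precision + jitter * eye, lower=True)


def inv_negative_half_precision(lower):
    """inv(negative_half_precision) = -(L L^T)^-1 = -(L^T)^-1 L^-1."""
    eye = jnp.eye(lower.shape[0], dtype=lower.dtype)
    lower_inv = jsl.solve_triangular(lower, eye, lower=True)  # L^-1
    # trans=1 solves L^T X = L^-1, giving X = (L^T)^-1 L^-1.
    return -jsl.solve_triangular(lower, lower_inv, lower=True, trans=1)


def logdet_half_precision(lower):
    """log det(L L^T) = 2 * sum(log(diag(L)))."""
    return 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower)))


# Usage: a small covariance matrix, converted to -Σ⁻¹/2.
sigma = jnp.array([[2.0, 0.5], [0.5, 1.0]])
nhp = -0.5 * jnp.linalg.inv(sigma)
lower = cholesky_of_half_precision(nhp, jitter=0.0)
h_inv = inv_negative_half_precision(lower)  # should match -2 * sigma
ld = logdet_half_precision(lower)  # should match slogdet(-nhp)
```

Reusing one factorization for both quantities is the main appeal of the approach. When only products with the inverse are needed, `jax.scipy.linalg.cho_solve((lower, True), b)` avoids forming the inverse at all. A nonzero `jitter` biases both results slightly, so it trades a small error for robustness.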

daniel-dodd • Apr 23 '24 10:04

This looks very interesting!

Does it produce worse results than before in float32 mode?

Are you able to demonstrate the improved precision either using a new test, or by lowering the tolerance on an existing test?
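One way to run that comparison is sketched below: build an ill-conditioned symmetric positive-definite matrix, invert it in float32 with `jnp.linalg.inv` and with a Cholesky-based inverse, and measure each result against a float64 reference computed from the same float32-rounded input. The matrix size, condition number, seed, and helper names are illustrative assumptions; the script only reports the errors and makes no claim about which method wins.

```python
import numpy as np
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def cholesky_inv(spd):
    """Invert an SPD matrix as (L^T)^-1 L^-1, where L is its lower Cholesky factor."""
    eye = jnp.eye(spd.shape[0], dtype=spd.dtype)
    lower = jsl.cholesky(spd, lower=True)
    lower_inv = jsl.solve_triangular(lower, eye, lower=True)
    return jsl.solve_triangular(lower, lower_inv, lower=True, trans=1)


def ill_conditioned_spd(n, log10_cond, seed=0):
    """Random float64 SPD matrix with condition number 10**log10_cond."""
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((n, n)))  # random orthogonal matrix
    eigenvalues = np.logspace(0.0, -log10_cond, n)
    a = (q * eigenvalues) @ q.T  # Q diag(eigenvalues) Q^T
    return 0.5 * (a + a.T)  # enforce exact symmetry


def relative_error(approx, exact):
    """Frobenius-norm relative error, computed in float64."""
    diff = np.asarray(approx, dtype=np.float64) - exact
    return float(np.linalg.norm(diff) / np.linalg.norm(exact))


a32 = jnp.asarray(ill_conditioned_spd(n=5, log10_cond=4), dtype=jnp.float32)
# Reference: float64 inverse of exactly the float32 matrix both methods see.
reference = np.linalg.inv(np.asarray(a32, dtype=np.float64))

err_inv = relative_error(jnp.linalg.inv(a32), reference)
err_chol = relative_error(cholesky_inv(a32), reference)
print(f"float32 inv: {err_inv:.2e}   float32 cholesky: {err_chol:.2e}")
```

Turning this into a regression test would mean asserting something like `err_chol <= err_inv` over a range of condition numbers, but whether that holds is exactly what the experiment needs to establish first.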

Also, unfortunately, `solve_triangular` is not part of the Array API (at least not yet; feel free to propose it). In the long run, I'd like to support the Array API.
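A possible workaround, sketched under the assumption that only Array API functions are allowed, is to keep the Cholesky factor but replace `solve_triangular` with the general `linalg.solve`, using (L Lᵀ)⁻¹ = (L⁻¹)ᵀ L⁻¹. The function below takes the array namespace `xp` as a parameter (a hypothetical convention for this sketch) and uses only `xp.eye`, `xp.linalg.cholesky`, `xp.linalg.solve`, `.T`, and `@`.

```python
import numpy as np
import jax.numpy as jnp


def inv_via_cholesky(xp, spd):
    """Invert a 2-D SPD matrix using only Array API linear-algebra functions.

    xp is an array namespace such as jax.numpy or NumPy.
    """
    eye = xp.eye(spd.shape[0], dtype=spd.dtype)
    lower = xp.linalg.cholesky(spd)  # lower-triangular factor by default
    # General solve: correct, but does not exploit the triangular structure.
    lower_inv = xp.linalg.solve(lower, eye)  # L^-1
    return lower_inv.T @ lower_inv  # (L^-1)^T L^-1 == (L L^T)^-1


a = np.array([[4.0, 1.0], [1.0, 3.0]])
inv_np = inv_via_cholesky(np, a)
inv_jnp = inv_via_cholesky(jnp, jnp.asarray(a, dtype=jnp.float32))
```

The trade-off is that a general solve factorizes L again instead of doing a cheap triangular substitution, so it costs more and may give back part of the stability benefit; it is a stopgap until something like `solve_triangular` exists in the standard.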

PS: Are you missing a minus sign where you set half_precision?

NeilGirdhar • Apr 23 '24 11:04