Use Cholesky decomposition instead of `jnp.linalg.inv` in Multivariate Gaussian Distributions.
Thanks for creating such a nice library.
I have experienced numerical instability in the Multivariate Gaussian Distributions. For better numerical stability, I wondered if it would be better to replace, e.g., https://github.com/NeilGirdhar/efax/blob/e6533df2d0aa9e775665ef8c5457d29277ef8f40/efax/_src/distributions/multivariate_normal/arbitrary.py#L48
With Cholesky-based operations, e.g.:

```python
import jax.numpy as jnp
import jax.scipy as jsp

JITTER = 1e-6
half_precision = self.negative_half_precision
# Jitter the diagonal so the factorization stays stable; lower=True gives L with H = L @ L.T.
eye = jnp.eye(half_precision.shape[0])
lower_h = jsp.linalg.cholesky(half_precision + JITTER * eye, lower=True)
lower_h_inv = jsp.linalg.solve_triangular(lower_h, eye, lower=True)
# H^{-1} = L^{-T} L^{-1}: a second triangular solve, against L^T via trans='T'.
h_inv = -jsp.linalg.solve_triangular(lower_h, lower_h_inv, lower=True, trans='T')
```

We can also compute the log determinant as `ld = 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower_h)))`. One downside of using Cholesky decomposition in JAX, though, is that it really wants you to be working in float64.
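As a quick standalone sanity check (on a small SPD matrix built here just for illustration, not efax's actual parameters), the Cholesky log determinant agrees with `jnp.linalg.slogdet`:

```python
import jax.numpy as jnp
import jax.scipy as jsp

# Small symmetric positive-definite matrix for illustration.
a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])

lower = jsp.linalg.cholesky(a, lower=True)
# log|A| = 2 * sum(log(diag(L))) when A = L @ L.T.
ld_cholesky = 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower)))
_, ld_slogdet = jnp.linalg.slogdet(a)
print(float(ld_cholesky), float(ld_slogdet))  # both ≈ log(11) ≈ 2.3979
```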
This looks very interesting!
Does it produce worse results than before in float32 mode?
Are you able to demonstrate the improved precision either using a new test, or by lowering the tolerance on an existing test?
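One way such a comparison might look as a standalone sketch (the ill-conditioned test matrix, seed, and tolerances here are illustrative assumptions, not efax's actual test harness): invert the same float32 SPD matrix with `jnp.linalg.inv` and with Cholesky-based triangular solves, then measure both against a float64 reference.

```python
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp

# Ill-conditioned SPD matrix A = Q D Q^T with eigenvalues spanning ~1e5.
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((5, 5)))
d = np.diag(np.logspace(0, 5, 5))
a64 = q @ d @ q.T
a32 = jnp.asarray(a64, dtype=jnp.float32)

ref = np.linalg.inv(a64)  # float64 ground truth

# Direct inverse in float32.
inv_direct = jnp.linalg.inv(a32)

# Cholesky-based inverse in float32: A^{-1} = L^{-T} L^{-1}.
eye = jnp.eye(5, dtype=jnp.float32)
lower = jsp.linalg.cholesky(a32, lower=True)
lower_inv = jsp.linalg.solve_triangular(lower, eye, lower=True)
inv_chol = jsp.linalg.solve_triangular(lower, lower_inv, lower=True, trans='T')

err_direct = float(jnp.max(jnp.abs(inv_direct - ref)))
err_chol = float(jnp.max(jnp.abs(inv_chol - ref)))
print(err_direct, err_chol)
```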
Also, it's a bit unfortunate, but `solve_triangular` is not part of the Array API (at least not yet, but feel free to propose it). In the long run, I'd like to support the Array API.
PS: Are you missing a minus sign where you set half_precision?