
Use Cholesky decomposition instead of jnp.linalg.inv in Multivariate Gaussian Distributions #27

@daniel-dodd


Thanks for creating such a nice library.

I have run into numerical instability in the Multivariate Gaussian Distributions. For better numerical stability, I wondered whether it would be better to replace explicit matrix inversion, e.g.,

h_inv = jnp.linalg.inv(self.negative_half_precision)

with Cholesky-based operations, e.g.,

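import jax.numpy as jnp
import jax.scipy.linalg as jsl

JITTER = 1e-6
# negative_half_precision = -1/2 * precision is negative definite, so negate it before factoring
half_precision = -self.negative_half_precision
eye = jnp.eye(half_precision.shape[0], dtype=half_precision.dtype)
# Lower factor L with half_precision + JITTER * I = L @ L.T (cholesky returns the upper factor unless lower=True)
lower_h = jsl.cholesky(half_precision + JITTER * eye, lower=True)
lower_h_inv = jsl.solve_triangular(lower_h, eye, lower=True)  # L^{-1}
# inv(negative_half_precision) = -inv(half_precision) = -(L^{-T} L^{-1})
h_inv = -jsl.solve_triangular(lower_h, lower_h_inv, lower=True, trans="T")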
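A self-contained sketch of the same idea is below. It assumes `negative_half_precision` is the natural parameter -1/2 * precision (so its negation is positive definite), and it uses `jax.scipy.linalg.cho_factor` / `cho_solve` in place of the two explicit triangular solves. The helper name `inverse_negative_half_precision` and the default jitter are hypothetical, chosen for illustration only.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


# Hypothetical helper: inv(negative_half_precision) without calling jnp.linalg.inv.
def inverse_negative_half_precision(negative_half_precision, jitter=1e-6):
    n = negative_half_precision.shape[-1]
    eye = jnp.eye(n, dtype=negative_half_precision.dtype)
    # Assumption: negative_half_precision = -0.5 * precision, so this is positive definite.
    half_precision = -negative_half_precision + jitter * eye
    # Lower Cholesky factor L with half_precision = L @ L.T.
    c_and_lower = cho_factor(half_precision, lower=True)
    # Solve half_precision @ X = I, i.e. X = inv(half_precision).
    half_precision_inv = cho_solve(c_and_lower, eye)
    # inv(-A) = -inv(A).
    return -half_precision_inv
```

For example, with precision = [[4, 1], [1, 3]] and jitter=0.0, the result should agree with jnp.linalg.inv(-0.5 * precision) up to float32 round-off. cho_solve is used here only because it bundles the forward and backward triangular solves into one call.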

We can also get the log determinant of the jittered matrix that was factored from the same Cholesky factor, as ld = 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower_h))). The downside of using the Cholesky decomposition in JAX, though, is that it really wants you to be working in float64.
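A minimal sketch of the log-determinant computation in float64 follows. JAX uses 32-bit floats by default, and the documented `jax_enable_x64` flag has to be set before any arrays are created. It again assumes `negative_half_precision` is -1/2 * precision; the helper name `half_precision_logdet` is hypothetical.

```python
import jax

# JAX defaults to float32; opt in to float64 before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import cholesky


# Hypothetical helper: log det(-negative_half_precision + jitter * I).
# Without jitter this equals log|det(negative_half_precision)|, since det(-A) = (-1)^n det(A).
def half_precision_logdet(negative_half_precision, jitter=1e-6):
    n = negative_half_precision.shape[-1]
    eye = jnp.eye(n, dtype=negative_half_precision.dtype)
    half_precision = -negative_half_precision + jitter * eye
    # Request the lower factor explicitly; the default is upper.
    lower_h = cholesky(half_precision, lower=True)
    # det(L @ L.T) = prod(diag(L))**2, so log det = 2 * sum(log(diag(L))).
    return 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower_h)))
```

With precision = [[4, 1], [1, 3]], half_precision is [[2, 0.5], [0.5, 1.5]], whose determinant is 2 * 1.5 - 0.5**2 = 2.75, so half_precision_logdet(-0.5 * precision, jitter=0.0) returns log(2.75).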

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
