Thanks for creating such a nice library.
I have run into numerical instability in the multivariate Gaussian distributions. For better numerical stability, I wonder whether it would be better to replace explicit inverses, e.g.,

```python
h_inv = jnp.linalg.inv(self.negative_half_precision)
```

with Cholesky-based operations, e.g.,
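```python
import jax.numpy as jnp
import jax.scipy as jsp

JITTER = 1e-6
# Assumes negative_half_precision is -0.5 * precision (negative definite),
# so it is negated to obtain a positive-definite matrix before factorizing.
half_precision = -self.negative_half_precision
eye = jnp.eye(half_precision.shape[0], dtype=half_precision.dtype)
# jsp.linalg.cholesky returns the upper factor by default, so request the lower one.
lower_h = jsp.linalg.cholesky(half_precision + JITTER * eye, lower=True)
lower_h_inv = jsp.linalg.solve_triangular(lower_h, eye, lower=True)
# Solve L^T X = L^{-1}; the minus sign gives the inverse of the original negative matrix.
h_inv = -jsp.linalg.solve_triangular(lower_h, lower_h_inv, lower=True, trans="T")
```

We can also compute the log-determinant of the (jittered) half precision as `ld = 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower_h)))`. One downside of using the Cholesky decomposition in JAX, though, is that it really wants you to work in float64.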
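To make the suggestion easier to try out, here is a minimal standalone sketch of the same idea. The helper name `chol_inv_and_logdet`, its `jitter` argument, and the 2x2 example matrix are all made up for illustration and are not part of the library; it also assumes the input is `-0.5 * precision`, as the attribute name suggests. It uses `cho_solve` instead of two explicit triangular solves, and enables float64 via the `jax_enable_x64` flag.

```python
import jax

# Enable float64 before any arrays are created (JAX defaults to float32).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, cholesky


def chol_inv_and_logdet(neg_half_precision, jitter=1e-6):
    """Hypothetical helper, not part of the library.

    Assumes `neg_half_precision` is -0.5 * precision, i.e. negative definite.
    Returns its inverse and the log-determinant of the negated (positive-definite)
    matrix, both computed from a single Cholesky factorization.
    """
    pos = -neg_half_precision  # positive definite under the assumption above
    eye = jnp.eye(pos.shape[0], dtype=pos.dtype)
    lower = cholesky(pos + jitter * eye, lower=True)
    # cho_solve solves (L L^T) X = I using the factor, giving the inverse of `pos`.
    pos_inv = cho_solve((lower, True), eye)
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diagonal(lower)))
    # inv(-P) = -inv(P)
    return -pos_inv, logdet


# Small made-up example: a 2x2 precision matrix.
precision = jnp.array([[2.0, 0.5], [0.5, 1.0]])
neg_half_precision = -0.5 * precision
h_inv, logdet = chol_inv_and_logdet(neg_half_precision, jitter=0.0)
```

With `jitter=0.0` the result should agree with `jnp.linalg.inv(neg_half_precision)` up to floating-point error; a nonzero jitter trades a small bias for robustness when the matrix is close to singular.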