DADVI: initialise variational means from model initial point, not zeros #672

Open
jaj42 wants to merge 1 commit into pymc-devs:main from jaj42:jaj42

@jaj42 jaj42 commented Apr 13, 2026

fit_dadvi constructs its starting optimisation point x0 by setting all
variational means ({var}_mu) to zero. For models with non-zero prior
means in unconstrained space, this places the fixed DADVI draws far from the
region of positive likelihood. The result is that many of the 30 fixed draws
produce logp = -inf, the mean DADVI objective is +inf, and the optimiser
fails immediately with:

ValueError: array must not contain infs or NaNs

This error was hit when using fit_dadvi(gradient_backend="jax"). The PyTensor backend may handle the non-finite values gracefully.
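The failure chain described above can be sketched with plain NumPy. The per-draw log-densities below are made-up numbers, and the helper name check_finite is hypothetical. NumPy's asarray_chkfinite is used because it raises a ValueError with this exact message; whether the optimiser reaches the error through that call is an assumption, not something confirmed here.

```python
import numpy as np

# Hypothetical per-draw log-densities at x0: one draw falls outside the support.
logp_per_draw = np.array([-3.2, -4.1, -np.inf, -2.7])

# Averaging the negative log-density: a single -inf makes the whole mean +inf.
objective = np.mean(-logp_per_draw)


def check_finite(value):
    """Return True if NumPy's finiteness check accepts the value."""
    try:
        np.asarray_chkfinite(value)
        return True
    except ValueError as err:
        # Message text comes from NumPy itself.
        print("rejected:", err)
        return False


objective_is_usable = check_finite(objective)
```

One non-finite draw is enough to make the averaged objective unusable, which is consistent with the optimiser failing on its first evaluation.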

The same model samples successfully with NUTS.

Please note that I used AI to trace and patch this problem.

Cause

dadvi_initial_point = {
    f"{var_name}_mu": np.zeros_like(value).ravel()   # always zeros
    for var_name, value in initial_point_dict.items()
}

initial_point_dict already contains the prior means in unconstrained space.
np.zeros_like discards that information.

Minimal reproduction

import numpy as np
import pymc as pm
from pymc_extras.inference import fit_dadvi

with pm.Model() as model:
    mu = pm.LogNormal("mu", mu=np.log(4.5), sigma=0.5)
    obs = pm.Normal("obs", mu=mu, sigma=1.0, observed=np.array([4.0, 5.0, 4.5]))
    idata = fit_dadvi(gradient_backend="jax")
    # → ValueError: array must not contain infs or NaNs

With 30 N(0, 1) draws centred at zero rather than at log(4.5) ≈ 1.5, many draws
map to mu ≈ exp(-2) to exp(-3), producing near-zero model predictions and
logp = -inf. In the real-world case that exposed this bug (a 3-compartment PK
model with proportional error, 48 subjects, 107 unconstrained parameters),
15 out of 30 fixed draws gave logp = -inf at x0.
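To make the shift concrete, here is a small NumPy sketch of where 30 fixed draws land on the constrained scale under each initialisation. The seed, the unit standard deviation, and all variable names are assumptions for illustration only. It does not evaluate the model's logp, so it shows the location of the draws rather than reproducing the failure.

```python
import numpy as np

rng = np.random.default_rng(0)          # arbitrary seed, for repeatability only
z = rng.standard_normal(30)             # 30 fixed standard-normal draws

prior_mean = np.log(4.5)                # unconstrained location used in the reproduction

# Unconstrained draws under each initialisation (unit standard deviation assumed).
draws_zero_init = 0.0 + z
draws_prior_init = prior_mean + z

# Back-transform to the constrained (positive) scale seen by the likelihood.
mu_zero_init = np.exp(draws_zero_init)
mu_prior_init = np.exp(draws_prior_init)

print("median mu, zero init :", np.median(mu_zero_init))
print("median mu, prior init:", np.median(mu_prior_init))
```

Because the shift is additive on the log scale, every constrained draw under the zero initialisation is smaller by a factor of 4.5 than its counterpart centred at log(4.5).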

Notes

The DADVI paper (Giordano, Ingram & Broderick, 2024) provides no basis for
initialising at zero.

Algorithm 2 states:

procedure DADVI
t ← 0 / Fix N / Draw 𝒵 / while Not converged do …

The starting value of η = (μ, ξ) is left unspecified.
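As a rough illustration that the starting value is a free input, the sketch below runs a toy fixed-draw objective from two different starts. Everything in it is an assumption made for illustration: the one-dimensional unit-variance Gaussian target, plain gradient descent with a fixed step budget in place of the convergence test, and the name fit_toy_dadvi. It is not the paper's algorithm as published nor the pymc-extras implementation.

```python
import numpy as np


def fit_toy_dadvi(mu0, xi0, z, target_mean, step=0.1, n_steps=2000):
    """Minimise a fixed-draw KL estimate for a 1-D Gaussian target.

    Objective: mean(0.5 * (mu + exp(xi) * z - target_mean) ** 2) - xi,
    i.e. negative mean log-density (up to a constant) minus the entropy term.
    """
    mu, xi = float(mu0), float(xi0)
    for _ in range(n_steps):                # fixed budget stands in for "while not converged"
        scaled = np.exp(xi) * z             # sigma * z_n for each fixed draw
        resid = mu + scaled - target_mean   # theta_n - target_mean
        grad_mu = resid.mean()
        grad_xi = (resid * scaled).mean() - 1.0   # the -1 comes from the entropy term
        mu -= step * grad_mu
        xi -= step * grad_xi
    return np.array([mu, xi])


rng = np.random.default_rng(0)
z = rng.standard_normal(30)                 # fix N = 30 and draw Z once
target_mean = np.log(4.5)

eta_from_zero = fit_toy_dadvi(0.0, 0.0, z, target_mean)
eta_from_target = fit_toy_dadvi(target_mean, 0.0, z, target_mean)
```

In this well-behaved toy both starts reach the same η, which is consistent with the starting value mattering mainly when the objective is non-finite at x0.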

Replace np.zeros_like(value) with np.asarray(value) so that the
variational means start at the model's initial point in unconstrained
space instead of at zero.
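The difference between the two expressions can be checked in isolation. The dictionary below is an illustrative stand-in for the model's initial point; its key and value are assumptions chosen to mirror the reproduction above, not output captured from PyMC.

```python
import numpy as np

# Illustrative stand-in for the model's initial point in unconstrained space.
# The key and value are assumptions, not values read from a real model.
initial_point_dict = {"mu_log__": np.array(np.log(4.5))}

# Current behaviour: every variational mean starts at zero.
zeros_start = {
    f"{var_name}_mu": np.zeros_like(value).ravel()
    for var_name, value in initial_point_dict.items()
}

# Proposed behaviour: keep the values already stored in the initial point.
proposed_start = {
    f"{var_name}_mu": np.asarray(value).ravel()
    for var_name, value in initial_point_dict.items()
}

print(zeros_start)
print(proposed_start)
```

Both versions produce flat arrays of the same shape, so the change affects only the starting values and not the layout of x0.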
@jaj42 jaj42 marked this pull request as ready for review April 13, 2026 13:15
@jaj42
Author

jaj42 commented Apr 13, 2026

Please note that CI has a linting failure which is not due to my diff but to previous code in the same module.

@ricardoV94
Member

I also see a function above that can take an initial point but gets None.

Anyway CC @martiningram

@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.59%. Comparing base (2a299e2) to head (402fda1).

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main     #672       +/-   ##
===========================================
+ Coverage   66.55%   76.59%   +10.03%     
===========================================
  Files          73       73               
  Lines        8088     8088               
===========================================
+ Hits         5383     6195      +812     
+ Misses       2705     1893      -812     
Files with missing lines Coverage Δ
pymc_extras/inference/dadvi/dadvi.py 98.55% <ø> (ø)

... and 25 files with indirect coverage changes



3 participants