Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ you should jump to {ref}`array_stats_api` and read forward.
:toctree: generated/

arviz_stats.bfmi
arviz_stats.diagnose
arviz_stats.ess
arviz_stats.loo_pit
arviz_stats.mcse
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from arviz_stats.psense import psense, psense_summary
from arviz_stats.metrics import bayesian_r2, kl_divergence, metrics, residual_r2, wasserstein
from arviz_stats.sampling_diagnostics import bfmi, ess, mcse, rhat, rhat_nested
from arviz_stats.sampling_diagnostics import bfmi, ess, mcse, rhat, rhat_nested, diagnose
from arviz_stats.summary import summary, ci_in_rope, mean, median, mode
from arviz_stats.manipulation import thin, weight_predictions
from arviz_stats.bayes_factor import bayes_factor
Expand Down
325 changes: 324 additions & 1 deletion src/arviz_stats/sampling_diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Functions for sampling diagnostics."""

import sys

import numpy as np
import xarray as xr
from arviz_base import convert_to_dataset
from arviz_base import convert_to_dataset, convert_to_datatree, rcParams

from arviz_stats.utils import _apply_multi_input_function, get_array_function
from arviz_stats.validate import validate_dims
Expand Down Expand Up @@ -615,3 +617,324 @@ def bfmi(
coords=coords,
**kwargs,
)


def diagnose(
    data,
    *,
    var_names=None,
    filter_vars=None,
    coords=None,
    sample_dims=None,
    group="posterior",
    rhat_max=1.01,
    ess_min_ratio=0.001,
    bfmi_threshold=0.3,
    show_diagnostics=True,
    return_diagnostics=False,
):
    """Run comprehensive diagnostic checks for MCMC sampling.

    This function performs diagnostic checks on MCMC samples similar to CmdStan's diagnose
    utility. It checks for:

    - Divergent transitions
    - Maximum tree depth saturation
    - Low E-BFMI (Energy Bayesian Fraction of Missing Information)
    - Low effective sample size (ESS)
    - High R-hat values

    See [1]_ and [2]_ for more details. You can also check
    https://arviz-devs.github.io/EABM/Chapters/MCMC_diagnostics.html
    for a more practical overview.

    Parameters
    ----------
    data : DataTree, Dataset, or InferenceData-like
        Input data. To be able to compute all diagnostics, the data should contain MCMC
        posterior samples (see ``group`` argument) and sampler statistics (in "sample_stats"
        group).
    var_names : str or list of str, optional
        Names of variables to check for R-hat and ESS diagnostics.
        If None, checks all variables.
    filter_vars : {None, "like", "regex"}, default None
        How to filter variable names. See :func:`filter_vars` for details.
    coords : dict, optional
        Coordinates to select a subset of the data.
    sample_dims : iterable of hashable, optional
        Dimensions to be considered sample dimensions.
        Default from ``rcParams["data.sample_dims"]``.
    group : str, default "posterior"
        Group to check for convergence diagnostics (R-hat, ESS). The total number of
        samples used in all percentages and ratios is also taken from this group.
    rhat_max : float, default 1.01
        Maximum acceptable R-hat value. Parameters with R-hat > rhat_max
        will be flagged. The larger of the "rank" and "folded" R-hat is used.
    ess_min_ratio : float, default 0.001
        Minimum acceptable ratio of ESS to total samples. Parameters with
        ESS/N < ess_min_ratio will be flagged, where ESS is the smaller of the
        "bulk" and "tail" estimates.
        A flag is also emitted if the bulk ESS is lower than 100 * number of chains.
    bfmi_threshold : float, default 0.3
        Minimum acceptable E-BFMI value. Values below this threshold indicate
        potential issues with the sampler's exploration.
    show_diagnostics : bool, default True
        If True, print diagnostic messages to stdout. If False, return results silently.
    return_diagnostics : bool, default False
        If True, return a dictionary with detailed diagnostic results in addition
        to the boolean has_errors flag.

    Returns
    -------
    has_errors : bool
        True if any diagnostic checks failed, False otherwise.
    diagnostics : dict, optional
        Only returned if return_diagnostics=True. The sampler-specific entries
        ("divergent", "treedepth", "bfmi") are only present when the corresponding
        variable is found in the "sample_stats" group.

        - "divergent": dict with keys "n_divergent", "pct", "total_samples"
        - "treedepth": dict with keys "n_max", "pct", "total_samples"
        - "bfmi": dict with keys "bfmi_values", "failed_chains", "threshold"
        - "ess": dict with keys "bad_params", "ess_values", "threshold_ratio", "total_samples"
        - "rhat": dict with keys "bad_params", "rhat_values", "threshold"

    Examples
    --------
    Get diagnostics printed to stdout:

    .. ipython::

        In [1]: import arviz_stats as azs
           ...: from arviz_base import load_arviz_data
           ...: data = load_arviz_data('centered_eight')
           ...: azs.diagnose(data)

    Get detailed diagnostic information without printing messages:

    .. ipython::

        In [1]: _, diagnostics = azs.diagnose(data, return_diagnostics=True, show_diagnostics=False)
           ...: diagnostics

    See Also
    --------
    rhat : Compute R-hat convergence diagnostic
    ess : Compute effective sample size
    bfmi : Compute Bayesian fraction of missing information
    summary : Create a data frame with summary statistics, including diagnostics.

    References
    ----------
    .. [1] Vehtari et al. *Rank-normalization, folding, and localization: An improved Rhat for
        assessing convergence of MCMC*. Bayesian Analysis. 16(2) (2021)
        https://doi.org/10.1214/20-BA1221. arXiv preprint https://arxiv.org/abs/1903.08008
    .. [2] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
        Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
    """
    # NOTE(review): a reviewer suggested accepting a ``stats`` dictionary of pre-computed
    # elements (e.g. ``stats={"rhat": rhat_nested(...)}``) so callers can swap in alternative
    # estimators; not implemented here.
    # NOTE(review): a reviewer believes the BFMI threshold is now also available as an
    # rcParam; the explicit ``bfmi_threshold`` argument is kept until that is confirmed.
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]

    dt = convert_to_datatree(data)

    has_errors = False
    diagnostics_results = {}
    messages = []

    # Sampler-specific checks are skipped (with a message) when there is no sample_stats group.
    sample_stats = dt.get("sample_stats")
    sample_stats_available = sample_stats is not None
    if not sample_stats_available:
        messages.append("No sample_stats group found. Skipping sampler-specific diagnostics.")

    posterior = dt[group]

    # Number of draws across all sample dimensions that are actually present in ``group``.
    total_samples = np.prod([posterior.sizes[dim] for dim in sample_dims if dim in posterior.sizes])

    # Check divergences
    if sample_stats_available and "diverging" in sample_stats:
        diverging = sample_stats["diverging"]
        n_divergent = int(diverging.sum().values)

        diagnostics_results["divergent"] = {
            "n_divergent": n_divergent,
            "pct": 100 * n_divergent / total_samples,
            "total_samples": total_samples,
        }

        messages.append("Divergences")
        if n_divergent > 0:
            has_errors = True
            pct = diagnostics_results["divergent"]["pct"]
            messages.append(
                f"{n_divergent} of {total_samples} ({pct:.2f}%) transitions ended with a "
                "divergence.\n"
                "These divergent transitions indicate that HMC is not fully able to explore "
                "the posterior distribution.\n"
                "Try increasing adapt delta closer to 1.\n"
                "If this doesn't remove all divergences, try to reparameterize the model."
            )
        else:
            messages.append("No divergent transitions found.")

    # Check tree depth
    # NOTE(review): a reviewer pointed out that "reached_max_treedepth" is not listed in the
    # ArviZ sample_stats schema and is presumably PyMC-only; other backends would skip this.
    if sample_stats_available and "reached_max_treedepth" in sample_stats:
        reached_max_treedepth = sample_stats["reached_max_treedepth"]
        n_max = int(reached_max_treedepth.sum().values)

        diagnostics_results["treedepth"] = {
            "n_max": n_max,
            "pct": 100 * n_max / total_samples,
            "total_samples": total_samples,
        }

        messages.append("\nTree depth")
        if n_max:
            has_errors = True
            pct = diagnostics_results["treedepth"]["pct"]
            messages.append(
                f"{n_max} of {total_samples} ({pct:.2f}%) transitions hit the maximum treedepth "
                "limit.\n"
                "Trajectories that are prematurely terminated due to this limit will result "
                "in slow exploration.\nFor optimal performance, increase this limit."
            )
        else:
            messages.append("Treedepth satisfactory for all transitions.")

    # Check E-BFMI
    if sample_stats_available and "energy" in sample_stats:
        bfmi_values = bfmi(dt, sample_dims=sample_dims)["energy"]

        # Compare against the user-supplied threshold (not a hard-coded 0.3) so the check
        # agrees with the threshold reported in the message and in the returned dictionary.
        low_bfmi = bfmi_values < bfmi_threshold
        chain_indices = low_bfmi.where(low_bfmi, drop=True).coords["chain"].values.tolist()

        diagnostics_results["bfmi"] = {
            "bfmi_values": bfmi_values.values,
            "failed_chains": chain_indices,
            "threshold": bfmi_threshold,
        }

        messages.append("\nE-BFMI")
        if chain_indices:
            has_errors = True
            for chain_idx in chain_indices:
                bfmi_val = bfmi_values.sel(chain=chain_idx).item()
                messages.append(f"Chain {chain_idx}: E-BFMI = {bfmi_val:.3f}")
            messages.append(
                f"E-BFMI values are below the threshold {bfmi_threshold:.2f} which suggests that "
                "HMC may have trouble exploring the target distribution.\n"
                "If possible, try to reparameterize the model."
            )
        else:
            messages.append("E-BFMI satisfactory for all chains.")

    # Keyword arguments shared by every ESS and R-hat computation below.
    convergence_kwargs = {
        "sample_dims": sample_dims,
        "group": group,
        "var_names": var_names,
        "filter_vars": filter_vars,
        "coords": coords,
    }

    # Check ESS
    ess_bulk = ess(dt, method="bulk", **convergence_kwargs)
    ess_tail = ess(dt, method="tail", **convergence_kwargs)

    # A parameter is judged on the worse (smaller) of its bulk and tail ESS.
    ess_min = np.minimum(ess_bulk.dataset, ess_tail.dataset)
    ess_ratio = ess_min / total_samples
    bad_ess_params = [var for var in ess_ratio.data_vars if (ess_ratio[var] < ess_min_ratio).any()]

    # Absolute floor: 100 effective draws per chain, checked on the bulk ESS only.
    ess_threshold = 100 * len(posterior.coords["chain"])
    below_minimum_params = [
        var for var in ess_bulk.dataset.data_vars if (ess_bulk.dataset[var] < ess_threshold).any()
    ]

    diagnostics_results["ess"] = {
        "bad_params": bad_ess_params,
        "ess_values": ess_min[bad_ess_params],
        "threshold_ratio": ess_min_ratio,
        "total_samples": total_samples,
    }

    messages.append("\nESS")
    if bad_ess_params:
        has_errors = True
        messages.append(
            f"The following parameters has fewer than {ess_min_ratio:.3f} effective draws per "
            f"transition:\n  {', '.join(bad_ess_params)}\n"
            "Such low values indicate that the effective sample size estimators may be "
            "biased high and actual performance may be substantially lower than quoted."
        )

    if below_minimum_params:
        has_errors = True
        messages.append(
            f"The following parameters has fewer than {ess_threshold} effective samples:\n"
            f"  {', '.join(below_minimum_params)}\n"
            "This suggests that the sampler may not have fully explored the posterior "
            "distribution for this parameter.\nConsider reparameterizing the model or "
            "increasing the number of samples."
        )

    if not bad_ess_params and not below_minimum_params:
        messages.append("Effective sample size satisfactory for all parameters.")

    # Check R-hat
    rhat_rank = rhat(dt, method="rank", **convergence_kwargs)
    rhat_folded = rhat(dt, method="folded", **convergence_kwargs)

    # A parameter is judged on the worse (larger) of its rank and folded R-hat.
    rhat_max_vals = np.maximum(rhat_rank.dataset, rhat_folded.dataset)
    bad_rhat_params = [
        var for var in rhat_max_vals.data_vars if (rhat_max_vals[var] > rhat_max).any()
    ]

    diagnostics_results["rhat"] = {
        "bad_params": bad_rhat_params,
        "rhat_values": rhat_max_vals[bad_rhat_params],
        "threshold": rhat_max,
    }

    messages.append("\nR-hat")
    if bad_rhat_params:
        has_errors = True
        messages.append(
            f"The following parameters has R-hat values greater than {rhat_max:.2f}:\n"
            f"  {', '.join(bad_rhat_params)}\n"
            "Such high values indicate incomplete mixing and biased estimation.\n"
            "You should consider regularizing your model with additional prior information or "
            "a more effective parameterization."
        )
    else:
        messages.append("R-hat values satisfactory for all parameters.")

    if not has_errors:
        messages.append("\nProcessing complete, no problems detected.")

    if show_diagnostics:
        print("\n".join(messages), file=sys.stdout)

    if return_diagnostics:
        return has_errors, diagnostics_results

    return has_errors