Add diagnose #299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Changes from all commits: d5fbf7f, 2c45f62, 9c3b567, b4df4ee, 09d248e
@@ -1,8 +1,10 @@
"""Functions for sampling diagnostics."""

+import sys
+
import numpy as np
import xarray as xr
-from arviz_base import convert_to_dataset
+from arviz_base import convert_to_dataset, convert_to_datatree, rcParams

from arviz_stats.utils import _apply_multi_input_function, get_array_function
from arviz_stats.validate import validate_dims

@@ -615,3 +617,324 @@ def bfmi(
        coords=coords,
        **kwargs,
    )


def diagnose(
    data,
    *,
    var_names=None,
    filter_vars=None,
    coords=None,
    sample_dims=None,
    group="posterior",
    rhat_max=1.01,
    ess_min_ratio=0.001,
    bfmi_threshold=0.3,
    show_diagnostics=True,
    return_diagnostics=False,
):
    """Run comprehensive diagnostic checks for MCMC sampling.

    This function performs diagnostic checks on MCMC samples similar to CmdStan's diagnose
    utility. It checks for:

    - Divergent transitions
    - Maximum tree depth saturation
    - Low E-BFMI (Energy Bayesian Fraction of Missing Information)
    - Low effective sample size (ESS)
    - High R-hat values

    See [1]_ and [2]_ for more details. You can also check
    https://arviz-devs.github.io/EABM/Chapters/MCMC_diagnostics.html for a more practical
    overview.

    Parameters
    ----------
    data : DataTree, Dataset, or InferenceData-like
        Input data. To be able to compute all diagnostics, the data should contain MCMC
        posterior samples (see ``group`` argument) and sampler statistics (in the
        "sample_stats" group).
    var_names : str or list of str, optional
        Names of variables to check for R-hat and ESS diagnostics.
        If None, checks all variables.
    filter_vars : {None, "like", "regex"}, default None
        How to filter variable names. See :func:`filter_vars` for details.
    coords : dict, optional
        Coordinates to select a subset of the data.
    sample_dims : iterable of hashable, optional
        Dimensions to be considered sample dimensions.
        Default from ``rcParams["data.sample_dims"]``.
    group : str, default "posterior"
        Group to check for convergence diagnostics (R-hat, ESS).
    rhat_max : float, default 1.01
        Maximum acceptable R-hat value. Parameters with R-hat > rhat_max
        will be flagged.
    ess_min_ratio : float, default 0.001
        Minimum acceptable ratio of ESS to total samples. Parameters with
        ESS/N < ess_min_ratio will be flagged.
        A flag is also emitted if ESS is lower than 100 * number of chains.
    bfmi_threshold : float, default 0.3
Member
we have this as an rcParam I think now
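If such an rcParam exists, the default could come from it rather than a hardcoded 0.3. A minimal sketch of what that could look like — the key name "stats.bfmi_threshold" is a placeholder for illustration, not necessarily the real rcParams entry:

    from arviz_base import rcParams

    def _resolve_bfmi_threshold(bfmi_threshold=None):
        """Fall back to an rcParams-configured E-BFMI threshold when none is given."""
        if bfmi_threshold is None:
            # "stats.bfmi_threshold" is an illustrative key; the actual rcParams name may differ
            return rcParams.get("stats.bfmi_threshold", 0.3)
        return bfmi_threshold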
        Minimum acceptable E-BFMI value. Values below this threshold indicate
        potential issues with the sampler's exploration.
    show_diagnostics : bool, default True
        If True, print diagnostic messages to stdout. If False, return results silently.
    return_diagnostics : bool, default False
        If True, return a dictionary with detailed diagnostic results in addition
        to the boolean has_errors flag.

    Returns
    -------
    has_errors : bool
        True if any diagnostic checks failed, False otherwise.
    diagnostics : dict, optional
        Only returned if return_diagnostics=True.

        - "divergent": dict with keys "n_divergent", "pct", "total_samples"
        - "treedepth": dict with keys "n_max", "pct", "total_samples"
        - "bfmi": dict with keys "bfmi_values", "failed_chains", "threshold"
        - "ess": dict with keys "bad_params", "ess_values", "threshold_ratio", "total_samples"
        - "rhat": dict with keys "bad_params", "rhat_values", "threshold"

    Examples
    --------
    Get diagnostics printed to stdout:

    .. ipython::

        In [1]: import arviz_stats as azs
           ...: from arviz_base import load_arviz_data
           ...: data = load_arviz_data('centered_eight')
           ...: azs.diagnose(data)

    Get detailed diagnostic information without printing messages:

    .. ipython::

        In [1]: _, diagnostics = azs.diagnose(data, return_diagnostics=True, show_diagnostics=False)
           ...: diagnostics

    See Also
    --------
    rhat : Compute R-hat convergence diagnostic
    ess : Compute effective sample size
    bfmi : Compute Bayesian fraction of missing information
    summary : Create a data frame with summary statistics, including diagnostics.

    References
    ----------
    .. [1] Vehtari et al. *Rank-normalization, folding, and localization: An improved Rhat for
        assessing convergence of MCMC*. Bayesian Analysis. 16(2) (2021)
        https://doi.org/10.1214/20-BA1221. arXiv preprint https://arxiv.org/abs/1903.08008
    .. [2] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
        Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]

    dt = convert_to_datatree(data)

    has_errors = False
    diagnostics_results = {}
    messages = []

    sample_stats = dt.get("sample_stats")
    if sample_stats is None:
        messages.append("No sample_stats group found. Skipping sampler-specific diagnostics.")
        sample_stats_available = False
    else:
        sample_stats_available = True

    posterior = dt[group]

    total_samples = np.prod(
        [posterior.sizes[dim] for dim in sample_dims if dim in posterior.sizes]
    )

    # Check divergences
    if sample_stats_available and "diverging" in sample_stats:
        diverging = sample_stats["diverging"]
        n_divergent = int(diverging.sum().values)

        diagnostics_results["divergent"] = {
            "n_divergent": n_divergent,
            "pct": 100 * n_divergent / total_samples,
            "total_samples": total_samples,
        }

        messages.append("Divergences")
        if n_divergent > 0:
            has_errors = True
            pct = diagnostics_results["divergent"]["pct"]
            messages.append(
                f"{n_divergent} of {total_samples} ({pct:.2f}%) transitions ended with a "
                "divergence.\n"
                "These divergent transitions indicate that HMC is not fully able to explore "
                "the posterior distribution.\n"
                "Try increasing adapt delta closer to 1.\n"
                "If this doesn't remove all divergences, try to reparameterize the model."
            )
        else:
            messages.append("No divergent transitions found.")

    # Check tree depth
    if sample_stats_available and "reached_max_treedepth" in sample_stats:
        reached_max_treedepth = sample_stats["reached_max_treedepth"]
I know you mentioned this isn't ready for review yet, but I noticed this bug when trying to skim through at a high level, so figured I might as well call it out: it's checking for a `tree_depth` variable.

Contributor (Author)
Good catch, in the first version I followed cmdstan too closely and I used tree_depth.

Contributor (Author)
This is ready for review, if you have more comments. I was just waiting for feedback before adding tests.

Member
We should add
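For context, a minimal sketch of the naming mismatch discussed above: CmdStan-style outputs record an integer `tree_depth` (with a separate `max_treedepth` setting), while ArviZ-converted `sample_stats` typically store the boolean `reached_max_treedepth` the code now uses. This helper is illustration only, not part of the PR, and the default `max_treedepth=10` is an assumption mirroring Stan's default:

    def _saturated_treedepth(sample_stats, max_treedepth=10):
        """Return a boolean array flagging transitions that saturated the tree depth."""
        if "reached_max_treedepth" in sample_stats:
            # preferred: the boolean flag stored by most ArviZ converters
            return sample_stats["reached_max_treedepth"]
        if "tree_depth" in sample_stats:
            # fallback: derive the flag from a CmdStan-style integer tree depth
            return sample_stats["tree_depth"] >= max_treedepth
        raise KeyError("no tree depth information found in sample_stats")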
        n_max = int(reached_max_treedepth.sum().values)

        diagnostics_results["treedepth"] = {
            "n_max": n_max,
            "pct": 100 * n_max / total_samples,
            "total_samples": total_samples,
        }

        messages.append("\nTree depth")
        if n_max:
            has_errors = True
            pct = diagnostics_results["treedepth"]["pct"]
            messages.append(
                f"{n_max} of {total_samples} ({pct:.2f}%) transitions hit the maximum treedepth "
                "limit.\n"
                "Trajectories that are prematurely terminated due to this limit will result "
                "in slow exploration.\nFor optimal performance, increase this limit."
            )
        else:
            messages.append("Treedepth satisfactory for all transitions.")

    # Check E-BFMI
    if sample_stats_available and "energy" in sample_stats:
        bfmi_values = bfmi(dt, sample_dims=sample_dims)["energy"]

        low_bfmi = bfmi_values < bfmi_threshold
        chain_indices = low_bfmi.where(low_bfmi, drop=True).coords["chain"].values.tolist()

        diagnostics_results["bfmi"] = {
            "bfmi_values": bfmi_values.values,
            "failed_chains": chain_indices,
            "threshold": bfmi_threshold,
        }

        messages.append("\nE-BFMI")
        if chain_indices:
            has_errors = True
            for chain_idx in chain_indices:
                bfmi_val = bfmi_values.sel(chain=chain_idx).item()
                messages.append(f"Chain {chain_idx}: E-BFMI = {bfmi_val:.3f}")
            messages.append(
                f"E-BFMI values are below the threshold {bfmi_threshold:.2f}, which suggests that "
                "HMC may have trouble exploring the target distribution.\n"
                "If possible, try to reparameterize the model."
            )
        else:
            messages.append("E-BFMI satisfactory for all chains.")

    # Check ESS
    ess_bulk = ess(
        dt,
        sample_dims=sample_dims,
        group=group,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
        method="bulk",
    )
    ess_tail = ess(
        dt,
        sample_dims=sample_dims,
        group=group,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
        method="tail",
    )

    ess_min = np.minimum(ess_bulk.dataset, ess_tail.dataset)
    ess_ratio = ess_min / total_samples
    bad_ess_params = [
        var for var in ess_ratio.data_vars if (ess_ratio[var] < ess_min_ratio).any()
    ]

    ess_threshold = 100 * len(posterior.coords["chain"])
    below_minimum_params = [
        var for var in ess_bulk.dataset.data_vars if (ess_bulk.dataset[var] < ess_threshold).any()
    ]

    diagnostics_results["ess"] = {
        "bad_params": bad_ess_params,
        "ess_values": ess_min[bad_ess_params],
        "threshold_ratio": ess_min_ratio,
        "total_samples": total_samples,
    }

    messages.append("\nESS")
    if bad_ess_params:
        has_errors = True
        messages.append(
            f"The following parameters have fewer than {ess_min_ratio:.3f} effective draws per "
            f"transition:\n {', '.join(bad_ess_params)}\n"
            "Such low values indicate that the effective sample size estimators may be "
            "biased high and actual performance may be substantially lower than quoted."
        )

    if below_minimum_params:
        has_errors = True
        messages.append(
            f"The following parameters have fewer than {ess_threshold} effective samples:\n"
            f" {', '.join(below_minimum_params)}\n"
            "This suggests that the sampler may not have fully explored the posterior "
            "distribution for this parameter.\nConsider reparameterizing the model or "
            "increasing the number of samples."
        )

    if not bad_ess_params and not below_minimum_params:
        messages.append("Effective sample size satisfactory for all parameters.")

    # Check R-hat
    rhat_rank = rhat(
        dt,
        sample_dims=sample_dims,
        group=group,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
        method="rank",
    )
    rhat_folded = rhat(
        dt,
        sample_dims=sample_dims,
        group=group,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
        method="folded",
    )

    rhat_max_vals = np.maximum(rhat_rank.dataset, rhat_folded.dataset)
    bad_rhat_params = [
        var for var in rhat_max_vals.data_vars if (rhat_max_vals[var] > rhat_max).any()
    ]

    diagnostics_results["rhat"] = {
        "bad_params": bad_rhat_params,
        "rhat_values": rhat_max_vals[bad_rhat_params],
        "threshold": rhat_max,
    }

    messages.append("\nR-hat")
    if bad_rhat_params:
        has_errors = True
        messages.append(
            f"The following parameters have R-hat values greater than {rhat_max:.2f}:\n"
            f" {', '.join(bad_rhat_params)}\n"
            "Such high values indicate incomplete mixing and biased estimation.\n"
            "You should consider regularizing your model with additional prior information or "
            "a more effective parameterization."
        )
    else:
        messages.append("R-hat values satisfactory for all parameters.")

    if not has_errors:
        messages.append("\nProcessing complete, no problems detected.")

    if show_diagnostics:
        print("\n".join(messages), file=sys.stdout)

    if return_diagnostics:
        return has_errors, diagnostics_results

    return has_errors
maybe we could take a page from the plots side here and allow a `stats` dictionary argument. That being said, here I would only allow providing pre-computed elements to keep things simple. For example: […] In that particular case it would make it easier to take advantage of `diagnose` even if wanting to use rhat_nested instead of regular rhat. They could also change the probability at which to compute ess_tail.
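A rough sketch of how that could look from the caller's side. The `stats` keyword and the accepted keys are hypothetical (this is the reviewer's idea, not part of the PR), and passing `prob` to `ess` for the tail method is assumed from the mention of changing the tail probability:

    import arviz_stats as azs
    from arviz_base import load_arviz_data

    data = load_arviz_data("centered_eight")

    # Pre-compute only the pieces the caller wants to customize,
    # e.g. a different R-hat variant or a custom tail probability for ESS.
    precomputed = {
        "rhat": azs.rhat(data, method="rank"),  # could be a nested R-hat variant instead
        "ess_tail": azs.ess(data, method="tail", prob=0.025),  # assumed `prob` argument
    }

    # Hypothetical signature: diagnose would use the supplied values and only
    # compute whatever is missing from the dictionary.
    # azs.diagnose(data, stats=precomputed)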