-
Notifications
You must be signed in to change notification settings - Fork 82
New ADVI API #635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jessegrabowski
wants to merge
8
commits into
pymc-devs:main
Choose a base branch
from
jessegrabowski:advi-refactor
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
New ADVI API #635
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
8d1f528
Move stuff over from pymc
jessegrabowski 1423593
Training helper sketch
jessegrabowski b5c2ca6
Add training example
jessegrabowski c530923
Progress bar and graceful keyboard interrupt
jessegrabowski d7e9c6f
Fix STL estimator
jessegrabowski bfdb98b
Add radon example
jessegrabowski 6b612f1
Move compile functions to compile.py
jessegrabowski f3d1759
Update notebook
jessegrabowski File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Copyright 2025 - present The PyMC Developers | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from dataclasses import dataclass, field | ||
|
|
||
| import numpy as np | ||
| import pytensor.tensor as pt | ||
|
|
||
| from pymc.distributions import Normal | ||
| from pymc.logprob.basic import conditional_logp | ||
| from pymc.model.core import Deterministic, Model | ||
| from pytensor import graph_replace | ||
| from pytensor.gradient import disconnected_grad | ||
| from pytensor.graph.basic import Variable | ||
|
|
||
| from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes | ||
|
|
||
|
|
||
@dataclass(frozen=True)
class AutoGuideModel:
    """Container pairing a variational guide ``Model`` with its free parameters.

    Attributes
    ----------
    model : Model
        The guide (variational approximation) model. Each approximated RV of the
        target model is registered in the guide as a ``Deterministic`` carrying
        the same name.
    params_init_values : dict[Variable, np.ndarray]
        Mapping from each variational parameter (a symbolic input variable) to
        its initial numeric value.
    name_to_param : dict[str, Variable]
        Derived lookup from parameter name to parameter variable; populated in
        ``__post_init__``.
    """

    model: Model
    params_init_values: dict[Variable, np.ndarray]
    name_to_param: dict[str, Variable] = field(init=False)

    def __post_init__(self):
        # The dataclass is frozen, so bypass the blocked __setattr__ to
        # populate the derived name -> parameter lookup table.
        object.__setattr__(
            self,
            "name_to_param",
            {x.name: x for x in self.params_init_values.keys()},
        )

    @property
    def params(self) -> tuple[Variable, ...]:
        """All variational parameter variables, in insertion order."""
        return tuple(self.params_init_values.keys())

    def __getitem__(self, name: str) -> Variable:
        """Look up a variational parameter variable by its name."""
        return self.name_to_param[name]

    def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable:
        """Returns a graph representing the logp of the guide model, evaluated under draws from its random variables."""
        # Conditioning each Deterministic on itself yields the guide's density
        # evaluated at fresh draws from the guide's own random variables.
        # NOTE(review): the original comment here was truncated ("This allows
        # arbitrary"); presumably it means this allows arbitrary deterministic
        # reparameterizations of the guide RVs -- confirm with the author.
        logp_terms = conditional_logp(
            {rv: rv for rv in self.model.deterministics},
            warn_rvs=False,
        )
        logq = pt.sum([logp_term.sum() for logp_term in logp_terms.values()])

        if stick_the_landing:
            # Detach variational parameters from the gradient computation of logq.
            # This is the "stick the landing" estimator: it removes the score
            # function term, leaving only the path derivative, which lowers
            # gradient variance.
            repl = {p: disconnected_grad(p) for p in self.params}
            logq = graph_replace(logq, repl)

        return logq
|
|
||
|
|
||
def AutoDiagonalNormal(model) -> AutoGuideModel:
    """Build a mean-field (diagonal) Gaussian guide for every free RV in *model*.

    For each free random variable ``rv`` the guide introduces two symbolic
    parameters, ``{rv.name}_loc`` and ``{rv.name}_scale``, a standard-normal
    noise variable ``{rv.name}_z``, and a ``Deterministic`` named ``rv.name``
    equal to ``loc + softplus(scale) * z`` — the reparameterization trick, with
    softplus keeping the effective standard deviation positive.

    Parameters
    ----------
    model : Model
        The target PyMC model whose posterior is being approximated.

    Returns
    -------
    AutoGuideModel
        The guide model together with initial values for its parameters.
    """
    rvs = model.free_RVs
    shape_by_rv = dict(zip(rvs, get_symbolic_rv_shapes(rvs)))
    init_values: dict = {}

    with Model(coords=model.coords) as guide:
        for rv in rvs:
            rv_shape = shape_by_rv[rv]

            loc_param = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
            scale_param = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
            # TODO: Make these customizable
            init_values[loc_param] = pt.random.uniform(-1, 1, size=rv_shape).eval()
            init_values[scale_param] = pt.full(rv_shape, 0.1).eval()

            noise = Normal(
                f"{rv.name}_z",
                mu=0,
                sigma=1,
                shape=rv_shape,
            )
            Deterministic(
                rv.name,
                loc_param + pt.softplus(scale_param) * noise,
                dims=model.named_vars_to_dims.get(rv.name, None),
            )

    return AutoGuideModel(guide, init_values)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| from typing import Protocol | ||
|
|
||
| import numpy as np | ||
|
|
||
| from pymc import Model, compile | ||
| from pymc.pytensorf import rewrite_pregrad | ||
| from pytensor import tensor as pt | ||
|
|
||
| from pymc_extras.inference.advi.autoguide import AutoGuideModel | ||
| from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq | ||
| from pymc_extras.inference.advi.pytensorf import vectorize_random_graph | ||
|
|
||
|
|
||
class TrainingFn(Protocol):
    """Signature of a compiled SVI training function.

    Takes the number of Monte Carlo draws and the current variational
    parameter values; returns the loss followed by one gradient array per
    parameter.
    """

    def __call__(self, draws: int, *params: np.ndarray) -> tuple[np.ndarray, ...]: ...
|
|
||
|
|
||
def compile_svi_training_fn(
    model: Model,
    guide: AutoGuideModel,
    stick_the_landing: bool = True,
    minibatch: bool = False,
    **compile_kwargs,
) -> TrainingFn:
    """Compile a function computing the negative ELBO and its gradients.

    Parameters
    ----------
    model : Model
        The target probabilistic model.
    guide : AutoGuideModel
        The variational guide whose parameters are optimized.
    stick_the_landing : bool, default True
        Use the lower-variance STL gradient estimator (stops gradients from
        flowing through logq).
    minibatch : bool, default False
        If True, the model's data variables become additional inputs of the
        compiled function, so minibatches can be fed in at every step.
    **compile_kwargs
        Forwarded to ``pymc.compile``; ``trust_input=True`` is set by default.

    Returns
    -------
    TrainingFn
        Callable ``f(draws, *params[, *data])`` returning
        ``(negative_elbo, *gradients)``.
    """
    draws = pt.scalar("draws", dtype=int)
    params = guide.params
    inputs = [draws, *params]

    # NOTE(review): logp_scale is always 1, so the division below is currently
    # a no-op. For minibatch training the model logp should presumably be
    # scaled by (total data size / batch size) for an unbiased ELBO estimate
    # -- confirm whether this is an unfinished TODO.
    logp_scale = 1

    if minibatch:
        data = model.data_vars
        inputs = [*inputs, *data]

    logp, logq = get_logp_logq(model, guide, stick_the_landing=stick_the_landing)

    # Negative ELBO for a single joint draw from the guide ...
    scalar_negative_elbo = advi_objective(logp / logp_scale, logq)
    # ... vectorized over `draws` Monte Carlo samples, then averaged.
    [negative_elbo_draws] = vectorize_random_graph([scalar_negative_elbo], batch_draws=draws)
    negative_elbo = negative_elbo_draws.mean(axis=0)

    negative_elbo_grads = pt.grad(rewrite_pregrad(negative_elbo), wrt=params)

    if "trust_input" not in compile_kwargs:
        compile_kwargs["trust_input"] = True

    f_loss_dloss = compile(
        inputs=inputs, outputs=[negative_elbo, *negative_elbo_grads], **compile_kwargs
    )

    return f_loss_dloss
|
|
||
|
|
||
def compile_sampling_fn(model: Model, guide: AutoGuideModel, **compile_kwargs) -> TrainingFn:
    """Compile a function that draws posterior samples from a fitted guide.

    The compiled function maps ``(draws, *params)`` to one array per free RV of
    *model*, each with a leading ``draws`` dimension, on the original
    (constrained) support of the model's variables.

    Parameters
    ----------
    model : Model
        The target probabilistic model.
    guide : AutoGuideModel
        The variational guide to sample from.
    **compile_kwargs
        Forwarded to ``pymc.compile``; ``trust_input=True`` is set by default.
    """
    draws = pt.scalar("draws", dtype=int)
    params = guide.params

    # BUG FIX: the transform lookup below previously zipped the *unfiltered*
    # model.rvs_to_values.keys() against a value-variable list that had been
    # filtered to exclude observed RVs, misaligning RVs with their guide
    # variables whenever an observed RV preceded a free RV in the dict.
    # Compute the filtered RV list once and use it consistently.
    unobserved_rvs = [rv for rv in model.rvs_to_values.keys() if rv not in model.observed_RVs]

    # The guide parameterizes each free RV's value variable (the unconstrained
    # space); look each one up in the guide model by name.
    parameterized_value_vars = [guide.model[rv.name] for rv in unobserved_rvs]

    # Map guide draws back to each RV's constrained support via the model's
    # registered transforms (identity when no transform is attached).
    transformed_vars = [
        transform.backward(parameterized_var)
        if (transform := model.rvs_to_transforms[rv]) is not None
        else parameterized_var
        for rv, parameterized_var in zip(unobserved_rvs, parameterized_value_vars)
    ]

    # Batch the whole sampling graph over a symbolic number of draws.
    sampled_rvs_draws = vectorize_random_graph(transformed_vars, batch_draws=draws)

    if "trust_input" not in compile_kwargs:
        compile_kwargs["trust_input"] = True

    f_sample = compile(inputs=[draws, *params], outputs=sampled_rvs_draws, **compile_kwargs)

    return f_sample
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from pymc import Model | ||
| from pytensor import graph_replace | ||
| from pytensor.tensor import TensorVariable | ||
|
|
||
| from pymc_extras.inference.advi.autoguide import AutoGuideModel | ||
|
|
||
|
|
||
def get_logp_logq(model: Model, guide: AutoGuideModel, stick_the_landing: bool = True):
    """Build symbolic log-densities of the model and of its variational guide.

    The model's log-probability graph is rewritten so that every unobserved
    value variable is replaced by the corresponding guide variable, i.e. the
    model density is evaluated at draws from the guide.

    Parameters
    ----------
    model : Model
        The probabilistic model.
    guide : AutoGuideModel
        The variational guide.
    stick_the_landing : bool, optional
        Whether to use the stick-the-landing (STL) gradient estimator, by default True.
        The STL estimator has lower gradient variance by removing the score function
        term from the gradient. When True, gradients are stopped from flowing
        through logq.

    Returns
    -------
    logp : TensorVariable
        Log probability of the model, evaluated at guide draws.
    logq : TensorVariable
        Log probability of the guide at its own draws.
    """
    replacements = {}
    for rv, value_var in model.rvs_to_values.items():
        if rv in model.observed_RVs:
            continue
        # Swap the model's value variable for the guide's matching variable.
        replacements[value_var] = guide.model[rv.name]

    logp = graph_replace(model.logp(), replacements)
    logq = guide.stochastic_logq(stick_the_landing=stick_the_landing)

    return logp, logq
|
|
||
|
|
||
def advi_objective(logp: TensorVariable, logq: TensorVariable):
    """Compute the negative ELBO objective for ADVI.

    Parameters
    ----------
    logp : TensorVariable
        Log probability of the model.
    logq : TensorVariable
        Log probability of the guide.

    Returns
    -------
    TensorVariable
        The negative ELBO, i.e. ``-(logp - logq)``.
    """
    # Minimizing the negative ELBO is equivalent to maximizing the ELBO.
    elbo = logp - logq
    return -elbo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from typing import TYPE_CHECKING, cast | ||
|
|
||
| from pymc import SymbolicRandomVariable | ||
| from pymc.distributions.shape_utils import change_dist_size | ||
| from pytensor import config | ||
| from pytensor import tensor as pt | ||
| from pytensor.graph import FunctionGraph, ancestors, vectorize_graph | ||
| from pytensor.tensor import TensorLike, TensorVariable | ||
| from pytensor.tensor.basic import infer_shape_db | ||
| from pytensor.tensor.random.op import RandomVariable | ||
| from pytensor.tensor.rewriting.shape import ShapeFeature | ||
|
|
||
| if TYPE_CHECKING: | ||
| pass | ||
|
|
||
|
|
||
def vectorize_random_graph(
    graph: Sequence[TensorVariable], batch_draws: TensorLike
) -> list[TensorVariable]:
    """Vectorize a stochastic graph over a new leading "draws" dimension.

    Every root random variable in *graph* (one with no other random variable
    among its direct inputs) is resized to carry an extra leading dimension of
    length ``batch_draws``; the rest of the graph is then vectorized to match,
    so each output gains a leading batch-of-draws axis.
    """
    # Find the root random nodes
    rvs = tuple(
        var
        for var in ancestors(graph)
        if (
            var.owner is not None
            and isinstance(var.owner.op, RandomVariable | SymbolicRandomVariable)
        )
    )
    rvs_set = set(rvs)
    # A root RV is one whose direct inputs contain none of the other RVs;
    # downstream RVs are handled automatically by vectorize_graph below.
    root_rvs = tuple(rv for rv in rvs if not (set(rv.owner.inputs) & rvs_set))

    # Vectorize graph by vectorizing root RVs
    batch_draws = pt.as_tensor(batch_draws, dtype=int)
    vectorized_replacements = {
        root_rv: change_dist_size(root_rv, new_size=batch_draws, expand=True)
        for root_rv in root_rvs
    }
    return cast(list[TensorVariable], vectorize_graph(graph, replace=vectorized_replacements))
|
|
||
|
|
||
def get_symbolic_rv_shapes(
    rvs: Sequence[TensorVariable], raise_if_rvs_in_graph: bool = True
) -> tuple[TensorVariable, ...]:
    """Return symbolic shape graphs for *rvs*, rewritten to not depend on the RVs.

    Shape-inference rewrites are applied so that each returned shape graph is
    expressed in terms of the RVs' inputs (e.g. distribution parameters)
    rather than the random draws themselves, where possible.

    Parameters
    ----------
    rvs : Sequence[TensorVariable]
        Random variables whose shapes are requested.
    raise_if_rvs_in_graph : bool, default True
        If True, raise ``ValueError`` when any shape graph still depends on one
        of the input RVs after rewriting.

    Returns
    -------
    tuple[TensorVariable, ...]
        One symbolic shape per input RV, in order.

    Raises
    ------
    ValueError
        If ``raise_if_rvs_in_graph`` is True and rewriting could not remove all
        RV dependencies from the shape graphs.
    """
    # TODO: Move me to pymc.pytensorf, this is needed often

    rv_shapes = [rv.shape for rv in rvs]
    # Run only the shape-inference rewrite database on a throwaway graph;
    # cxx="" avoids spurious C compilation during the rewrite pass.
    shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True)
    with config.change_flags(optdb__max_use_ratio=10, cxx=""):
        infer_shape_db.default_query.rewrite(shape_fg)
    rv_shapes = shape_fg.outputs

    if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))):
        # Fixed grammar of the error message ("depend the" -> "depend on the").
        raise ValueError(f"rv_shapes still depend on the following rvs {overlap}")

    return cast(tuple[TensorVariable, ...], tuple(rv_shapes))
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like we need a better name than
stick_the_landing. Also the function is supposed to return the logp and logq terms but the STL estimator is about returning only the path derivative component of the gradient.