Fix bug in gradient of Blockwise'd Scan #1482
Conversation
```python
# Obtain core_op gradients
with config.change_flags(compute_test_value="off"):
    safe_inputs = [
        tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
```
This line was the problematic one: `shape=(None,) * len(sig)`
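Why the lost static shape matters can be seen with a plain NumPy analogue (a sketch, not pytensor code): when a core input of static shape `(1,)` is broadcast in the forward pass, its gradient must be summed back down to shape `(1,)`, and a dummy input typed as `(None,)` hides the fact that this reduction is needed.

```python
import numpy as np

# NumPy analogue (not pytensor code): the forward pass broadcasts a
# shape-(1,) core input w against a shape-(3,) input x.
x = np.array([1.0, 2.0, 3.0])
w = np.array([10.0])  # broadcastable core input, static shape (1,)
y = x * w             # broadcast multiply, result shape (3,)

# Reverse-mode gradient of sum(y) w.r.t. w: since w was broadcast, its
# gradient must be reduced back to w's shape (1,).
g_w = np.sum(x * np.ones_like(y), keepdims=True)
print(g_w.shape)  # (1,)

# If the dummy graph only knows shape (None,), it cannot tell that this
# sum-reduction is required, and the gradient keeps the wrong shape (3,).
g_w_unreduced = x * np.ones_like(y)
print(g_w_unreduced.shape)  # (3,)
```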
```python
# FIXME: These core_outputs do not depend on core_inputs, not pretty
# It's not necessarily a problem because if they are referenced by the gradient,
# they get replaced later in vectorize. But if the Op was to make any decision
# by introspecting the dependencies of output on inputs it would fail badly!
```
Pull Request Overview
This PR fixes a gradient bug in Blockwise when batching scans by refactoring the L_op implementation and adds targeted tests for core-type gradients and scan gradients.
- Refactored `Blockwise.L_op` to simplify and correct core gradient extraction and batching logic.
- Renamed `test_op` to `my_test_op` in existing tests and added two new tests: `test_blockwise_grad_core_type` and `test_scan_gradient_core_type`.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/tensor/test_blockwise.py | Renamed test_op, imported scan, and added two new gradient tests. |
| pytensor/tensor/blockwise.py | Completely refactored the L_op method to remove the old helper and improve batching of core gradients. |
Comments suppressed due to low confidence (1)
pytensor/tensor/blockwise.py:353
- The new `core_inputs` comprehension no longer preserves `NullType` or `DisconnectedType` inputs as the old `_bgrad` helper did via `as_core`. If an input is a null/disconnected gradient, it should be passed through unchanged; otherwise downstream gradient logic may break.

```python
core_inputs = [
```
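The pass-through behaviour the comment asks for can be sketched with stand-in classes (`NullGrad`, `DisconnectedGrad`, and `as_core_like` below are made-up names standing in for pytensor's `NullType`/`DisconnectedType` variables and the old `as_core` helper; this is not pytensor's actual code):

```python
# Toy sketch of the reviewer's point, with made-up stand-in classes:
# gradients marked null/disconnected should pass through unchanged instead
# of being rebuilt as core-shaped tensors.
class NullGrad:          # stand-in for a NullType gradient variable
    pass

class DisconnectedGrad:  # stand-in for a DisconnectedType gradient variable
    pass

class CoreTensor:        # stand-in for a freshly built core-shaped dummy
    def __init__(self, shape):
        self.shape = shape

def as_core_like(inp, core_shape):
    # Special gradient markers are returned as-is; only real tensors get a
    # new core-shaped dummy.
    if isinstance(inp, (NullGrad, DisconnectedGrad)):
        return inp
    return CoreTensor(core_shape)

g = DisconnectedGrad()
passed_through = as_core_like(g, (1,))
rebuilt = as_core_like(object(), (1,))
print(passed_through is g)  # True
print(rebuilt.shape)        # (1,)
```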
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main    #1482   +/-   ##
=======================================
- Coverage   82.01%   82.00%   -0.01%
=======================================
  Files         214      214
  Lines       50426    50414      -12
  Branches     8903     8902       -1
=======================================
- Hits        41355    41343      -12
  Misses       6863     6863
  Partials     2208     2208
```
AlexAndorra left a comment
Thanks a lot @ricardoV94 !! Definitely unblocks us over on pymc-extras 🤩
I'd say it was much more than a one-liner though 😉
The fix was a one-liner, I just cleaned up stuff besides it. Check the PR commit by commit and you'll see it ;)
@ricardoV94, using this over on …:

```
[(d__logp/dP0_diag_log__),
 (d__logp/dinitial_trend),
 (d__logp/dar_params_logodds__),
 (d__logp/dsigma_trend_log__),
 (d__logp/dsigma_ar_log__),
 (d__logp/dsigma_obs_log__)]
```

But now... The following example will trigger the error, just using a small dataset of 15 data points and a batch size of 5 (one per president) (agg.csv). I'm running pytensor main (…):

```python
import numpy as np
import pandas as pd
import pymc as pm
import pymc_extras.statespace as pmss
import pytensor
import pytensor.tensor as pt
import xarray as xr

presidents = agg.president.unique()

mod = pmss.structural.LevelTrendComponent(order=2, innovations_order=[0, 1])
mod += pmss.structural.AutoregressiveComponent(order=1)
mod += pmss.structural.MeasurementError(name="obs")

ss_mod = mod.build(
    name="president",
    batch_coords={"president": presidents},  # this is gonna be the leftmost dimension
)

ss_array = (
    agg.set_index(["president", "month_id"])["approve_pr"].unstack("month_id").to_numpy()[..., None]
)  # dims=(president, timesteps, obs_dim)

initial_trend_dims, sigma_trend_dims, ar_param_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords | ss_mod.batch_coords) as model_1:
    P0_diag = pm.Gamma("P0_diag", alpha=5, beta=5, dims="president")
    P0 = pm.Deterministic(
        "P0", pt.eye(ss_mod.k_states)[None] * P0_diag[..., None, None], dims=("president", *P0_dims)
    )

    initial_trend = pm.Normal("initial_trend", dims=("president", *initial_trend_dims))
    ar_params = pm.Beta("ar_params", alpha=3, beta=3, dims=("president", *ar_param_dims))
    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=50, dims=("president", *sigma_trend_dims))
    sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, dims="president")
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, dims="president")

    ss_mod.build_statespace_graph(ss_array)
    idata = pm.sample()  # nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"}
```

This (or the …) fails with:

```
The Kernel crashed while executing code in the current cell or a previous cell.
Please review the code in the cell(s) to identify a possible cause of the failure.
```
This PR fixes a bug found in the work to add batch dimensions to the Statespace module in pymc-extras. The gradient of a Blockwise'd scan with a broadcastable signature in the inner graph (i.e., `shape=(1,)`) was failing, because Blockwise was creating a dummy node with variables that didn't respect the core shapes (i.e., `shape=(None,)`).

The fix is a one-liner (second commit).

The last commit refactors the `L_op` implementation, since the helper function isn't used anywhere else and some parts that were copied from Elemwise don't make sense (such as worrying that the core op's `L_op` might return `None`).

📚 Documentation preview 📚: https://pytensor--1482.org.readthedocs.build/en/1482/