feat: generate adjoint_rhs for continuous ODE models#173
Draft
sdwfrost wants to merge 6 commits intomrc-ide:mainfrom
Draft
feat: generate adjoint_rhs for continuous ODE models#173sdwfrost wants to merge 6 commits intomrc-ide:mainfrom
sdwfrost wants to merge 6 commits intomrc-ide:mainfrom
Conversation
Add symbolic Jacobian transpose generation for continuous (ODE) models with differentiate=TRUE, enabling gradient-based inference (HMC) for ODE models through dust2's new continuous adjoint infrastructure. Changes to parse_adjoint.R: - Fix adjoint packing order: force [state adjoints, param adjoints] layout so gradient extraction at offset n_state works correctly - Add adjoint_deriv() function: generates adjoint equations for RHS by differentiating deriv equations w.r.t. state and parameters - Fix adjoint_phase() filtering: include all non-stack (output) equations, not just those referenced by other equations Changes to generate_dust.R: - Add generate_dust_system_adjoint_rhs(): emits C++ adjoint_rhs() static method computing J^T * lambda (no negation - backward time convention handles sign) - Include adjoint_rhs in the generated adjoint methods block All 254 existing odin2 tests pass (0 failures). End-to-end HMC sampling verified: beta=0.5002 (true=0.5), gamma=0.0962 (true=0.1). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Enable gradient computation for models with array state variables, array intermediates, and 2D arrays. Changes span the parse and code generation stages: parse_adjoint.R: - Remove 'Implement differentiation with arrays' blocker - Build adjoint array metadata: adj_X inherits dims from X - Handle indexed adjoint references (adj_S[i], adj_S[i,j]) - Add reduce_loops for scalar-from-array accumulation - Unwrap OdinReduce to sum(X[]) for monty differentiator - Store array intermediate adjoints in internal (not stack) - Exclude dim_* equations from adjoint intermediate lists - Preserve state-before-parameter packing order generate_dust.R: - Fix adjoint array LHS indexing with expr_plus(idx, offset) - Handle reduce_loops with nested for loops and += accumulation - Strip const declaration for stack intermediates in reduce_loops - Add adjoint_split_accumulate() helper Tested: 1D arrays (3-group, 5-group), 2D arrays (age x vax), discrete-time arrays, array parameters. Gradient matches finite differences to ~1e-7 (continuous) / ~1e-9 (discrete). All 1083 existing tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…lex models - Add digamma to FUNCTIONS_MONTY_MATH in constants.R - Add adjoint_rewrap_sums() to convert bare sum() calls back to OdinReduce - Add adjoint_make_reduce() helper for proper OdinReduce construction - Handle partial index sums with OdinDim range entries - All 1083 tests pass Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Two fixes for complex array adjoint models: 1. Recursive dependency resolution in adjoint_phase(): the previous one-level-deep scan missed state variables referenced transitively through intermediate equations (e.g., D used by target_met_adults_t used by adjoint equations). Now iterates until convergence. 2. Free index variable reduction in generate_dust_assignment(): when a lower-dimensional adjoint variable depends on higher-dimensional arrays (e.g., adj_prop_infectious[i] from s_ij_sex[i,j]), the extra index variables need reduction loops. Added adjoint_find_free_indices() to detect free indices and generate zero-init + accumulation pattern. mpoxseir (1147-line discrete-time model, 47+ parameters, 2D arrays) now compiles with adjoint support. All 1083 tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…hain Two bugs in the array adjoint code generation: 1. Kronecker delta free-index resolution: when s_ij[i,j] = m[i,j] * prop_I[j], the adjoint for prop_I produced adj_prop_I[i] = sum_j(adj_s_ij[i,j] * m[i,j] * delta(j,i)) which only captures diagonal contributions. Fixed by detecting delta patterns, removing them, and swapping subscript index variables so the reduction covers all off-diagonal contributions correctly. 2. Shared intermediate adjoint chain: variables like p_IR = 1 - exp(-gamma) stored in shared state were excluded from adjoint equation generation, breaking the gradient chain adj_n_IR -> adj_p_IR -> adj_gamma. Fixed by finding shared intermediates on the dependency path from update equations to differentiated parameters and including them in adjoint generation. Also adds test-adjoint-stan.R with 24 tests comparing odin symbolic adjoint against Stan Math reverse-mode AD and finite differences for scalar SIR, 3-group mixing SIR, and 2D (age x vaccination) SIR models. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…tting Three fixes for array parameter adjoint code generation: 1. Array parameter adjoint packing (parse_adjoint.R): adj_beta, adj_gamma, adj_I0 were packed as scalars (rank 0) instead of arrays matching their original parameter dimensions. Added array table entries for parameter adjoints so they get correct rank, dims, and alias information. 2. Mixed free-index/direct term splitting (generate_dust.R): When an adjoint expression combines direct terms (adj_I[i]*(1-gamma[i]*dt)) with free-index terms (sum_j adj_s_ij[j,i]*beta[j,i]/N_pop[i]), the code generator now splits them: direct terms go outside the reduction loop, free-index terms go inside with Kronecker delta resolution. 3. Helper functions: adjoint_split_additive(), adjoint_expr_contains_free_index(), adjoint_expr_mentions(), adjoint_reconstruct_sum() for expression splitting. Verified: 5-group age-structured SIR with 35 array parameters matches Stan Math reverse-mode AD to relative error 1e-14. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
adjoint_rhscode generation for continuous-time (ODE) models, enabling symbolic backward adjoint gradient computation.What this does
For models using
deriv(), odin2 now generates anadjoint_rhs()C++ method that computes the Jacobian-transpose-vector product(∂f/∂y)ᵀ λsymbolically. This is used by dust2's backward ODE integrator (see mrc-ide/dust2#167) to compute exact gradients via the continuous adjoint method.Changes
R/generate_dust.R: Newgenerate_dust_adjoint_rhs()function that emits theadjoint_rhs()static method. Generates symbolic partial derivatives of the RHS w.r.t. state variables and parameters, producing efficient C++ code for the JᵀV product.R/parse_adjoint.R: Extended to handle continuous models — builds dependency graphs for ODE RHS terms, supports array models with Kronecker delta handling, recursive dependencies, and mixed shared/intermediate terms.R/constants.R: Addeddigammato monty math function mappings.tests/testthat/test-adjoint-stan.R: New test file (367 lines) validating generated gradients against Stan math library reference values for scalar and array ODE models.Features
digammafunction support for models using Beta/Gamma distributionsDependencies
This PR generates code that is consumed by dust2's continuous adjoint integrator. The companion dust2 PR is mrc-ide/dust2#167 (DP5-based backward adjoint with
zero_everyboundary handling). dust2 detectsadjoint_rhsvia SFINAE and falls back to finite differences when it's not available, so these PRs can be merged independently.Testing
test-adjoint-stan.Rvalidates symbolic gradients against Stan math reference for SIR ODE and array SEIR models