Skip to content

feat: generate adjoint_rhs for continuous ODE models#173

Draft
sdwfrost wants to merge 6 commits intomrc-ide:mainfrom
sdwfrost:feat/adjoint-rhs-codegen
Draft

feat: generate adjoint_rhs for continuous ODE models#173
sdwfrost wants to merge 6 commits intomrc-ide:mainfrom
sdwfrost:feat/adjoint-rhs-codegen

Conversation

@sdwfrost
Copy link
Copy Markdown

Summary

Adds adjoint_rhs code generation for continuous-time (ODE) models, enabling symbolic backward adjoint gradient computation.

What this does

For models using deriv(), odin2 now generates an adjoint_rhs() C++ method that computes the Jacobian-transpose-vector product (∂f/∂y)ᵀ λ symbolically. This is used by dust2's backward ODE integrator (see mrc-ide/dust2#167) to compute exact gradients via the continuous adjoint method.

Changes

  • R/generate_dust.R: New generate_dust_adjoint_rhs() function that emits the adjoint_rhs() static method. Generates symbolic partial derivatives of the RHS w.r.t. state variables and parameters, producing efficient C++ code for the JᵀV product.
  • R/parse_adjoint.R: Extended to handle continuous models — builds dependency graphs for ODE RHS terms, supports array models with Kronecker delta handling, recursive dependencies, and mixed shared/intermediate terms.
  • R/constants.R: Added digamma to monty math function mappings.
  • tests/testthat/test-adjoint-stan.R: New test file (367 lines) validating generated gradients against Stan math library reference values for scalar and array ODE models.

Features

  • Scalar and array ODE model support
  • Correct handling of shared intermediates in the adjoint chain rule
  • Kronecker delta for array index derivatives
  • Mixed-term splitting for parameters appearing in both shared and state-dependent expressions
  • Recursive dependency resolution for complex model hierarchies
  • digamma function support for models using Beta/Gamma distributions

Dependencies

This PR generates code that is consumed by dust2's continuous adjoint integrator. The companion dust2 PR is mrc-ide/dust2#167 (DP5-based backward adjoint with zero_every boundary handling). dust2 detects adjoint_rhs via SFINAE and falls back to finite differences when it's not available, so these PRs can be merged independently.

Testing

  • New test-adjoint-stan.R validates symbolic gradients against Stan math reference for SIR ODE and array SEIR models
  • Existing adjoint tests continue to pass

Simon Frost and others added 6 commits March 27, 2026 08:01
Add symbolic Jacobian transpose generation for continuous (ODE) models
with differentiate=TRUE, enabling gradient-based inference (HMC) for
ODE models through dust2's new continuous adjoint infrastructure.

Changes to parse_adjoint.R:
- Fix adjoint packing order: force [state adjoints, param adjoints]
  layout so gradient extraction at offset n_state works correctly
- Add adjoint_deriv() function: generates adjoint equations for RHS
  by differentiating deriv equations w.r.t. state and parameters
- Fix adjoint_phase() filtering: include all non-stack (output)
  equations, not just those referenced by other equations

Changes to generate_dust.R:
- Add generate_dust_system_adjoint_rhs(): emits C++ adjoint_rhs()
  static method computing J^T * lambda (no negation - backward time
  convention handles sign)
- Include adjoint_rhs in the generated adjoint methods block

All 254 existing odin2 tests pass (0 failures).
End-to-end HMC sampling verified: beta=0.5002 (true=0.5),
gamma=0.0962 (true=0.1).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Enable gradient computation for models with array state variables,
array intermediates, and 2D arrays. Changes span the parse and
code generation stages:

parse_adjoint.R:
- Remove 'Implement differentiation with arrays' blocker
- Build adjoint array metadata: adj_X inherits dims from X
- Handle indexed adjoint references (adj_S[i], adj_S[i,j])
- Add reduce_loops for scalar-from-array accumulation
- Unwrap OdinReduce to sum(X[]) for monty differentiator
- Store array intermediate adjoints in internal (not stack)
- Exclude dim_* equations from adjoint intermediate lists
- Preserve state-before-parameter packing order

generate_dust.R:
- Fix adjoint array LHS indexing with expr_plus(idx, offset)
- Handle reduce_loops with nested for loops and += accumulation
- Strip const declaration for stack intermediates in reduce_loops
- Add adjoint_split_accumulate() helper

Tested: 1D arrays (3-group, 5-group), 2D arrays (age x vax),
discrete-time arrays, array parameters. Gradient matches finite
differences to ~1e-7 (continuous) / ~1e-9 (discrete).
All 1083 existing tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…lex models

- Add digamma to FUNCTIONS_MONTY_MATH in constants.R
- Add adjoint_rewrap_sums() to convert bare sum() calls back to OdinReduce
- Add adjoint_make_reduce() helper for proper OdinReduce construction
- Handle partial index sums with OdinDim range entries
- All 1083 tests pass

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Two fixes for complex array adjoint models:

1. Recursive dependency resolution in adjoint_phase(): the previous
   one-level-deep scan missed state variables referenced transitively
   through intermediate equations (e.g., D used by target_met_adults_t
   used by adjoint equations). Now iterates until convergence.

2. Free index variable reduction in generate_dust_assignment(): when a
   lower-dimensional adjoint variable depends on higher-dimensional
   arrays (e.g., adj_prop_infectious[i] from s_ij_sex[i,j]), the
   extra index variables need reduction loops. Added
   adjoint_find_free_indices() to detect free indices and generate
   zero-init + accumulation pattern.

mpoxseir (1147-line discrete-time model, 47+ parameters, 2D arrays)
now compiles with adjoint support. All 1083 tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…hain

Two bugs in the array adjoint code generation:

1. Kronecker delta free-index resolution: when s_ij[i,j] = m[i,j] * prop_I[j],
   the adjoint for prop_I produced adj_prop_I[i] = sum_j(adj_s_ij[i,j] * m[i,j] * delta(j,i))
   which only captures diagonal contributions. Fixed by detecting delta patterns,
   removing them, and swapping subscript index variables so the reduction covers
   all off-diagonal contributions correctly.

2. Shared intermediate adjoint chain: variables like p_IR = 1 - exp(-gamma) stored
   in shared state were excluded from adjoint equation generation, breaking the
   gradient chain adj_n_IR -> adj_p_IR -> adj_gamma. Fixed by finding shared
   intermediates on the dependency path from update equations to differentiated
   parameters and including them in adjoint generation.

Also adds test-adjoint-stan.R with 24 tests comparing odin symbolic adjoint
against Stan Math reverse-mode AD and finite differences for scalar SIR,
3-group mixing SIR, and 2D (age x vaccination) SIR models.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…tting

Three fixes for array parameter adjoint code generation:

1. Array parameter adjoint packing (parse_adjoint.R): adj_beta, adj_gamma,
   adj_I0 were packed as scalars (rank 0) instead of arrays matching their
   original parameter dimensions. Added array table entries for parameter
   adjoints so they get correct rank, dims, and alias information.

2. Mixed free-index/direct term splitting (generate_dust.R): When an adjoint
   expression combines direct terms (adj_I[i]*(1-gamma[i]*dt)) with
   free-index terms (sum_j adj_s_ij[j,i]*beta[j,i]/N_pop[i]), the code
   generator now splits them: direct terms go outside the reduction loop,
   free-index terms go inside with Kronecker delta resolution.

3. Helper functions: adjoint_split_additive(), adjoint_expr_contains_free_index(),
   adjoint_expr_mentions(), adjoint_reconstruct_sum() for expression splitting.

Verified: 5-group age-structured SIR with 35 array parameters matches
Stan Math reverse-mode AD to relative error 1e-14.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant