Skip to content

NUTS Sampler for SIR code prohibitively slow #1711

@njs59

Description

@njs59

I seem to be getting an error when using NUTS sampler to fir to an SIR model
Samplers that do not require sensitivities run well but any that require sensitivities do not

import numpy as np, pints
from scipy.integrate import solve_ivp

class SIR(pints.ForwardModelS1):
    """SIR with sensitivities wrt [beta, gamma]. Output: I(t)."""
    def __init__(self, y0=(760., 3., 0.)):
        self.y0 = np.array(y0, float); self.N = float(sum(y0))
    def n_outputs(self): return 1
    def n_parameters(self): return 2

    # Base RHS: [S, I, R]
    def _rhs(self, t, y, b, g):
        S, I, R = y
        SI = S * I / self.N
        return np.array([-b*SI, b*SI - g*I, g*I], float)

    # Sensitivity RHS: [S, I, R, SB, IB, SG, IG]  (RB/RG not needed)
    def _rhs_sens(self, t, z, b, g):
        S, I, R, SB, IB, SG, IG = z
        invN = 1.0 / self.N; SI = S * I * invN
        dS = -b * SI
        dI =  b * SI - g * I
        dR =  g * I
        a = -b * I * invN; bS = -b * S * invN
        c =  b * I * invN; d =  b * S * invN - g
        # d/d(beta)
        dSB = a*SB + bS*IB + (-SI)
        dIB = c*SB + d*IB + ( SI)
        # d/d(gamma)
        dSG = a*SG + bS*IG + 0.0
        dIG = c*SG + d*IG - I
        return np.array([dS, dI, dR, dSB, dIB, dSG, dIG], float)

    def simulate(self, x, t):
        b, g = map(float, x); t = np.asarray(t, float)
        order = np.argsort(t); ts = t[order]; y0 = self.y0
        sol = solve_ivp(lambda tt, yy: self._rhs(tt, yy, b, g),
                        (float(ts[0]), float(ts[-1])), y0, t_eval=ts,
                        rtol=1e-6, atol=1e-6, method='RK45')
        if not sol.success: raise RuntimeError(sol.message)
        out = np.empty_like(t); out[order] = sol.y[1]; return out

    def simulateS1(self, x, t):
        b, g = map(float, x); t = np.asarray(t, float)
        order = np.argsort(t); ts = t[order]
        z0 = np.zeros(7); z0[:3] = self.y0
        sol = solve_ivp(lambda tt, zz: self._rhs_sens(tt, zz, b, g),
                        (float(ts[0]), float(ts[-1])), z0, t_eval=ts,
                        rtol=1e-6, atol=1e-6, method='RK45')
        if not sol.success: raise RuntimeError(sol.message)
        Z = sol.y.T
        I, dIb, dIg = Z[:,1], Z[:,4], Z[:,6]
        y = np.empty_like(t); dy = np.empty((len(t),2))
        y[order] = I; dy[order,0] = dIb; dy[order,1] = dIg
        return y, dy

# Tiny fixed dataset (14 points)
times  = np.linspace(1, 14, 14)
values = np.array([3, 8, 28, 75, 221, 291, 255, 235, 190, 126, 70, 28, 12, 5], float)

# Build PINTS problem & log-posterior (adds sigma as 3rd parameter)
model   = SIR(y0=(999., 1., 0.))
problem = pints.SingleOutputProblem(model, times, values)
like    = pints.GaussianLogLikelihood(problem)
prior   = pints.ComposedLogPrior(
    pints.LogUniformLogPrior(1e-5, 1e-2),  # beta
    pints.LogUniformLogPrior(1e-3, 1.0),   # gamma
    pints.LogUniformLogPrior(1e-4, 1.0),   # sigma
)
post = pints.LogPosterior(like, prior)

# NUTS (gradient-based), short run, single chain
x0 = [3e-4, 1e-1, 1e-2]
mcmc = pints.MCMCController(post, 1, [x0], method=pints.NoUTurnMCMC)
mcmc.set_max_iterations(5000); mcmc.set_log_to_screen(True)
for s in mcmc.samplers():
    s.set_max_tree_depth(6)
    s.set_delta(0.75)
chains = mcmc.run()
print("Chains shape:", chains.shape)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions