-
Notifications
You must be signed in to change notification settings - Fork 34
Open
Description
I seem to be getting an error when using NUTS sampler to fir to an SIR model
Samplers that do not require sensitivities run well but any that require sensitivities do not
import numpy as np, pints
from scipy.integrate import solve_ivp
class SIR(pints.ForwardModelS1):
"""SIR with sensitivities wrt [beta, gamma]. Output: I(t)."""
def __init__(self, y0=(760., 3., 0.)):
self.y0 = np.array(y0, float); self.N = float(sum(y0))
def n_outputs(self): return 1
def n_parameters(self): return 2
# Base RHS: [S, I, R]
def _rhs(self, t, y, b, g):
S, I, R = y
SI = S * I / self.N
return np.array([-b*SI, b*SI - g*I, g*I], float)
# Sensitivity RHS: [S, I, R, SB, IB, SG, IG] (RB/RG not needed)
def _rhs_sens(self, t, z, b, g):
S, I, R, SB, IB, SG, IG = z
invN = 1.0 / self.N; SI = S * I * invN
dS = -b * SI
dI = b * SI - g * I
dR = g * I
a = -b * I * invN; bS = -b * S * invN
c = b * I * invN; d = b * S * invN - g
# d/d(beta)
dSB = a*SB + bS*IB + (-SI)
dIB = c*SB + d*IB + ( SI)
# d/d(gamma)
dSG = a*SG + bS*IG + 0.0
dIG = c*SG + d*IG - I
return np.array([dS, dI, dR, dSB, dIB, dSG, dIG], float)
def simulate(self, x, t):
b, g = map(float, x); t = np.asarray(t, float)
order = np.argsort(t); ts = t[order]; y0 = self.y0
sol = solve_ivp(lambda tt, yy: self._rhs(tt, yy, b, g),
(float(ts[0]), float(ts[-1])), y0, t_eval=ts,
rtol=1e-6, atol=1e-6, method='RK45')
if not sol.success: raise RuntimeError(sol.message)
out = np.empty_like(t); out[order] = sol.y[1]; return out
def simulateS1(self, x, t):
b, g = map(float, x); t = np.asarray(t, float)
order = np.argsort(t); ts = t[order]
z0 = np.zeros(7); z0[:3] = self.y0
sol = solve_ivp(lambda tt, zz: self._rhs_sens(tt, zz, b, g),
(float(ts[0]), float(ts[-1])), z0, t_eval=ts,
rtol=1e-6, atol=1e-6, method='RK45')
if not sol.success: raise RuntimeError(sol.message)
Z = sol.y.T
I, dIb, dIg = Z[:,1], Z[:,4], Z[:,6]
y = np.empty_like(t); dy = np.empty((len(t),2))
y[order] = I; dy[order,0] = dIb; dy[order,1] = dIg
return y, dy
# Tiny fixed dataset (14 points)
times = np.linspace(1, 14, 14)
values = np.array([3, 8, 28, 75, 221, 291, 255, 235, 190, 126, 70, 28, 12, 5], float)
# Build PINTS problem & log-posterior (adds sigma as 3rd parameter)
model = SIR(y0=(999., 1., 0.))
problem = pints.SingleOutputProblem(model, times, values)
like = pints.GaussianLogLikelihood(problem)
prior = pints.ComposedLogPrior(
pints.LogUniformLogPrior(1e-5, 1e-2), # beta
pints.LogUniformLogPrior(1e-3, 1.0), # gamma
pints.LogUniformLogPrior(1e-4, 1.0), # sigma
)
post = pints.LogPosterior(like, prior)
# NUTS (gradient-based), short run, single chain
x0 = [3e-4, 1e-1, 1e-2]
mcmc = pints.MCMCController(post, 1, [x0], method=pints.NoUTurnMCMC)
mcmc.set_max_iterations(5000); mcmc.set_log_to_screen(True)
for s in mcmc.samplers():
s.set_max_tree_depth(6)
s.set_delta(0.75)
chains = mcmc.run()
print("Chains shape:", chains.shape)Metadata
Metadata
Assignees
Labels
No labels