in examples/contrib/epidemiology/sir.py [0:0]
def _credible_interval(x, lower=0.05, upper=0.95):
    """Return empirical ``(lower, upper)`` quantiles of ``x`` along dim 0.

    Dim 0 is treated as the sample dimension. Each quantile ``q`` is taken as
    the order statistic ``kthvalue(round(0.5 + q * n))``, where ``n`` is
    ``x.size(0)``, the number of samples actually present. Using the real
    count, not a separately supplied one, keeps the index consistent
    with the data.
    """
    num_samples = x.size(0)

    def kth(q):
        # Clamp to [1, num_samples] so extreme quantiles remain valid indices.
        k = min(max(int(round(0.5 + q * num_samples)), 1), num_samples)
        return x.kthvalue(k, dim=0).values

    return kth(lower), kth(upper)


def _plot_series(plt, days, x, median, duration, ylabel, title, obs=None, truth=None):
    """Draw the median and 90% CI of samples ``x`` (sample dim 0) in a new figure.

    ``obs`` (if given) is drawn over the first ``duration`` entries of
    ``days``; ``truth`` (if given) over all of ``days``. A vertical line marks
    the boundary between the observed window and the forecast.
    """
    plt.figure()
    p05, p95 = _credible_interval(x)
    plt.fill_between(days, p05, p95, color="red", alpha=0.3, label="90% CI")
    plt.plot(days, median, "r-", label="median")
    if obs is not None:
        plt.plot(days[:duration], obs, "k.", label="observed")
    if truth is not None:
        plt.plot(days, truth, "k--", label="truth")
    plt.axvline(duration - 0.5, color="gray", lw=1)
    plt.xlim(0, len(days) - 1)
    plt.ylim(0, None)
    plt.xlabel("day after first infection")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(loc="upper left")
    plt.tight_layout()


def predict(args, model, truth):
    """Forecast new infections with ``model``, log the median and optionally plot.

    Args:
        args: options namespace. Reads ``forecast``, ``plot``, ``duration``,
            ``population`` and ``heterogeneous``.
        model: object whose ``predict(forecast=...)`` returns a dict of sample
            tensors with the sample dimension first, and whose ``data`` holds
            the observed series plotted over the first ``args.duration`` days.
        truth: optional series plotted against the full
            ``args.duration + args.forecast`` day axis, or ``None``.

    Raises:
        KeyError: if the predicted samples contain neither ``"S2I"`` nor
            ``"E2I"``.
    """
    samples = model.predict(forecast=args.forecast)
    obs = model.data

    # New infections are recorded under "S2I" or, failing that, "E2I".
    new_I = samples.get("S2I", samples.get("E2I"))
    if new_I is None:
        raise KeyError("expected an 'S2I' or 'E2I' site in the predicted samples")
    median = new_I.median(dim=0).values
    logging.info("Median prediction of new infections (starting on day 0):\n%s",
                 " ".join(map(str, map(int, median))))

    # Optionally plot the latent and forecasted series of new infections.
    if not args.plot:
        return

    import matplotlib.pyplot as plt
    days = torch.arange(args.duration + args.forecast)
    _plot_series(plt, days, new_I, median, args.duration,
                 ylabel="new infections per day",
                 title="New infections in population of {}".format(args.population),
                 obs=obs, truth=truth)

    # Plot the Re time series (only read when args.heterogeneous; presumably
    # only that model variant produces an "Re" site -- confirm). The observed
    # data are infection counts, not Re values, so they are not overlaid here.
    if args.heterogeneous:
        Re = samples["Re"]
        _plot_series(plt, days, Re, Re.median(dim=0).values, args.duration,
                     ylabel="Re",
                     title="Effective reproductive number over time")