in examples/sir_hmc.py [0:0]
def predict(args, data, samples, truth=None):
    """Forecast ``args.forecast`` steps beyond the observed series.

    The prediction proceeds in three stages:

    1. ``continuous_model`` is conditioned on ``samples`` and run through
       ``infer_discrete`` to draw the discrete auxiliary variables for the
       time steps ``[0:duration]``.
    2. ``discrete_model`` is conditioned on the result and run forward on the
       data extended with ``args.forecast`` unobserved (``None``) steps,
       covering ``[duration:duration+forecast]``.
    3. The per-step sample sites named ``S_<t>``, ``I_<t>``, ``S2I_<t>`` and
       ``I2R_<t>`` are stacked along the last dimension and stored under the
       keys ``"S"``, ``"I"``, ``"S2I"`` and ``"I2R"``.

    The median of the stacked ``"S2I"`` series over dim 0 is logged, and when
    ``args.plot`` is truthy a matplotlib figure of that series is drawn.

    :param args: Options object; this function reads ``forecast``,
        ``num_samples``, ``duration``, ``plot`` and ``population``.
    :param data: Observed series, passed to the models and plotted as
        "observed" over the first ``args.duration`` time points.
    :param samples: Mapping of site names to values used to condition
        ``continuous_model``.
    :param truth: Optional series plotted as "truth" over the full interval.
    :returns: ``OrderedDict`` of all sample-site values recorded while running
        the conditioned ``discrete_model``, plus the four stacked series.
    """
    logging.info("Forecasting {} steps ahead...".format(args.forecast))
    particle_plate = pyro.plate("particles", args.num_samples, dim=-1)

    def _trace_sample_values(model_fn, observations):
        # Run model_fn under a trace and keep the value of every sample site,
        # in the order in which the trace recorded the sites.
        with poutine.trace() as tracer:
            model_fn(args, observations)
        values = OrderedDict()
        for site_name, site in tracer.trace.nodes.items():
            if site["type"] == "sample":
                values[site_name] = site["value"]
        return values

    # Stage 1: draw the discrete auxiliary variables given the continuous
    # ones, for time steps [0:duration]. infer_discrete performs the
    # forward-filter backward-sample pass here.
    smoothing_model = infer_discrete(
        particle_plate(poutine.condition(continuous_model, samples)),
        first_available_dim=-2,
    )
    samples = _trace_sample_values(smoothing_model, data)

    # Stage 2: run the generative process forward over the forecast window,
    # [duration:duration+forecast], with the future steps left unobserved.
    extended_data = list(data) + [None] * args.forecast
    forecasting_model = particle_plate(
        poutine.condition(discrete_model, samples)
    )
    samples = _trace_sample_values(forecasting_model, extended_data)

    # Stage 3: gather the per-step values of each series and stack them into
    # one tensor spanning the whole interval [0:duration+forecast].
    total_steps = args.duration + args.forecast
    for key in ("S", "I", "S2I", "I2R"):
        pattern = key + "_[0-9]+"
        series = []
        for site_name, value in samples.items():
            if re.match(pattern, site_name):
                series.append(value)
        assert len(series) == total_steps
        # Give the first entry the same shape as the second before stacking
        # (presumably the initial step lacks the particle dimension).
        series[0] = series[0].expand(series[1].shape)
        samples[key] = torch.stack(series, dim=-1)

    S2I = samples["S2I"]
    median = S2I.median(dim=0).values
    median_text = " ".join(str(int(count)) for count in median)
    logging.info("Median prediction of new infections (starting on day 0):\n{}"
                 .format(median_text))

    if not args.plot:
        return samples

    # Optional figure of the latent and forecasted new-infection series.
    import matplotlib.pyplot as plt

    plt.figure()
    days = torch.arange(total_steps)
    # Order statistics of S2I along dim 0 bounding the central 90% band.
    lower_rank = int(round(0.5 + 0.05 * args.num_samples))
    p05 = S2I.kthvalue(lower_rank, dim=0).values
    upper_rank = int(round(0.5 + 0.95 * args.num_samples))
    p95 = S2I.kthvalue(upper_rank, dim=0).values
    plt.fill_between(days, p05, p95, color="red", alpha=0.3, label="90% CI")
    plt.plot(days, median, "r-", label="median")
    plt.plot(days[:args.duration], data, "k.", label="observed")
    if truth is not None:
        plt.plot(days, truth, "k--", label="truth")
    # Vertical rule separating the observed window from the forecast window.
    plt.axvline(args.duration - 0.5, color="gray", lw=1)
    plt.xlim(0, len(days) - 1)
    plt.ylim(0, None)
    plt.xlabel("day after first infection")
    plt.ylabel("new infections per day")
    plt.title("New infections in population of {}".format(args.population))
    plt.legend(loc="upper left")
    plt.tight_layout()

    return samples