in pyro/contrib/epidemiology/compartmental.py [0:0]
def predict(self, forecast=0):
    """
    Predict latent variables and optionally forecast forward.

    This may be run only after :meth:`fit_mcmc` and draws the same
    ``num_samples`` as passed to :meth:`fit_mcmc`.

    :param int forecast: The number of time steps to forecast forward.
        Must be non-negative; the default ``0`` disables forecasting.
    :returns: A dictionary mapping sample site name (or compartment name)
        to a tensor whose first dimension corresponds to sample batching.
    :rtype: dict
    :raises ValueError: If ``forecast`` is negative.
    :raises RuntimeError: If ``self.samples`` is empty, e.g. because
        :meth:`fit_mcmc` has not been run yet.
    """
    # Forecasting "backward" is not meaningful; reject it up front instead
    # of handing a negative step count to the generative model.
    if forecast < 0:
        raise ValueError("Expected forecast >= 0, but got {}".format(forecast))
    if self.num_quant_bins > 1:
        _require_double_precision()
    if not self.samples:
        raise RuntimeError("Missing samples, try running .fit_mcmc() first")

    samples = self.samples
    num_samples = len(next(iter(samples.values())))
    # Vectorize every model run over the posterior draws, using a plate dim
    # to the left of all of the model's own plates.
    particle_plate = pyro.plate("particles", num_samples,
                                dim=-1 - self.max_plate_nesting)

    def collect_samples(trace, expand):
        # Gather values of genuine sample sites (skipping subsample and
        # factor sites) in trace order; optionally expand each value to the
        # shape of its site distribution, site["fn"].shape().
        result = OrderedDict()
        for name, site in trace.nodes.items():
            if site["type"] != "sample":
                continue
            if site_is_subsample(site) or site_is_factor(site):
                continue
            value = site["value"]
            if expand:
                value = value.expand(site["fn"].shape())
            result[name] = value
        return result

    def check_batched(result):
        # Internal invariant: every tensor is batched over the posterior draws.
        assert all(v.size(0) == num_samples for v in result.values()), \
            {k: tuple(v.shape) for k, v in result.items()}

    # Sample discrete auxiliary variables conditioned on the continuous
    # variables sampled by _quantized_model. This samples only time steps
    # [0:duration]. Here infer_discrete runs a forward-filter
    # backward-sample algorithm.
    logger.info("Predicting latent variables for %s time steps...",
                self.duration)
    model = self._sequential_model
    model = poutine.condition(model, samples)
    model = particle_plate(model)
    if not self.relaxed:
        model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting)
    trace = poutine.trace(model).get_trace()
    samples = collect_samples(trace, expand=True)
    check_batched(samples)

    # Optionally forecast with the forward _generative_model. This samples
    # time steps [duration:duration+forecast].
    if forecast:
        logger.info("Forecasting %s steps ahead...", forecast)
        model = self._generative_model
        model = poutine.condition(model, samples)
        model = particle_plate(model)
        trace = poutine.trace(model).get_trace(forecast)
        samples = collect_samples(trace, expand=False)

    # _concat_series is relied on to update ``samples`` in place (its return
    # value is not used here), merging per-time-step sites into series.
    self._concat_series(samples, trace, forecast)
    check_batched(samples)
    return samples