def predict()

in pyro/contrib/epidemiology/compartmental.py [0:0]


    def predict(self, forecast=0):
        """
        Draw latent variables from the fitted posterior and, optionally,
        simulate them forward in time.

        Must be called after :meth:`fit_mcmc`; it produces as many draws as
        the ``num_samples`` that was passed to :meth:`fit_mcmc`.

        :param int forecast: Number of time steps beyond ``self.duration`` to
            simulate with the forward generative model. The default ``0``
            skips forecasting.
        :returns: An ordered mapping from sample site name (or compartment
            name) to a tensor whose leading dimension indexes the posterior
            draws.
        :rtype: dict
        :raises RuntimeError: If no posterior samples are stored on the model
            (``self.samples`` is empty or otherwise falsy).
        """
        if self.num_quant_bins > 1:
            _require_double_precision()
        posterior = self.samples
        if not posterior:
            raise RuntimeError("Missing samples, try running .fit_mcmc() first")

        num_samples = len(next(iter(posterior.values())))
        # A single plate over the posterior draws, at dim -1 - max_plate_nesting,
        # is reused to vectorize both the smoothing and the forecasting pass.
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)

        def sample_sites(tr, expand):
            # Gather values of genuine sample statements from a trace, skipping
            # subsample sites and factor sites. When ``expand`` is set, each
            # value is broadcast to the full shape of its distribution.
            collected = OrderedDict()
            for site_name, node in tr.nodes.items():
                if node["type"] != "sample":
                    continue
                if site_is_subsample(node) or site_is_factor(node):
                    continue
                value = node["value"]
                if expand:
                    value = value.expand(node["fn"].shape())
                collected[site_name] = value
            return collected

        def check_leading_dim(values):
            # Sanity check: every tensor must be batched over the posterior
            # draws; on failure the message lists all shapes.
            assert all(v.size(0) == num_samples for v in values.values()), \
                {k: tuple(v.shape) for k, v in values.items()}

        # Pass 1: resample the discrete auxiliary variables for time steps
        # [0:duration], conditioned on the continuous posterior draws produced
        # by _quantized_model. Unless the model is relaxed, infer_discrete runs
        # forward-filter backward-sample, enumerating in dims to the left of
        # the particle plate.
        logger.info("Predicting latent variables for {} time steps..."
                    .format(self.duration))
        smoother = particle_plate(poutine.condition(self._sequential_model, posterior))
        if not self.relaxed:
            smoother = infer_discrete(smoother, first_available_dim=-2 - self.max_plate_nesting)
        trace = poutine.trace(smoother).get_trace()
        predicted = sample_sites(trace, expand=True)
        check_leading_dim(predicted)

        # Pass 2 (only when forecast > 0): run the forward generative model,
        # conditioned on pass 1, to cover time steps [duration:duration+forecast].
        if forecast:
            logger.info("Forecasting {} steps ahead...".format(forecast))
            simulator = particle_plate(poutine.condition(self._generative_model, predicted))
            trace = poutine.trace(simulator).get_trace(forecast)
            predicted = sample_sites(trace, expand=False)

        # `trace` is whichever pass ran last. _concat_series is defined
        # elsewhere; its return value is ignored, so it presumably updates
        # `predicted` in place (TODO confirm).
        self._concat_series(predicted, trace, forecast)
        check_leading_dim(predicted)
        return predicted