def fit_svi()

in pyro/contrib/epidemiology/compartmental.py


    def fit_svi(self, *,
                num_samples=100,
                num_steps=2000,
                num_particles=32,
                learning_rate=0.1,
                learning_rate_decay=0.01,
                betas=(0.8, 0.99),
                haar=True,
                init_scale=0.01,
                guide_rank=0,
                jit=False,
                log_every=200,
                **options):
        """
        Runs stochastic variational inference to generate posterior samples.

        This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
        attribute on completion.

        This approximate inference method is useful for quickly iterating on
        probabilistic models.

        :param int num_samples: Number of posterior samples to draw from the
            trained guide. Defaults to 100.
        :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
        :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
            particles per step.
        :param float learning_rate: Learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param float learning_rate_decay: Learning rate decay for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
            is the decay over the entire schedule, not the per-step decay.
        :param tuple betas: Momentum parameters for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
        :param int guide_rank: Rank of the auto normal guide. If zero (default)
            use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
            positive integer or None, use an
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
            If the string "full", use an
            :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
            latter two require more ``num_steps`` to fit.
        :param float init_scale: Initial scale of the autoguide's posterior
            approximation (passed to whichever guide is selected via
            ``guide_rank``).
        :param bool jit: Whether to use a jit compiled ELBO.
        :param int log_every: How often (in SVI steps) to log losses.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: Time series of SVI losses (useful to diagnose convergence).
        :rtype: list
        """
        # Save configuration for .predict().
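        # (These flags record that the continuous "relaxed" model, rather than
        # a quantized model, was used, so that .predict() can behave accordingly.)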
        self.relaxed = True
        self.num_quant_bins = 1

        # Setup Haar wavelet transform.
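        # The Haar reparameterizer rewrites each time-series latent in terms of
        # its Haar wavelet coefficients, which decorrelates the time dimension
        # and tends to make the SVI optimization landscape easier.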
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(0, self.duration, dims, supports)

        # Heuristically initialize to feasible latents.
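        # Keyword arguments prefixed with "heuristic_" are stripped of that
        # prefix and forwarded to self._heuristic(); anything else is rejected.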
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        assert not options, "unrecognized options: {}".format(", ".join(options))
        init_strategy = self._heuristic(haar, **heuristic_options)

        # Configure variational inference.
        logger.info("Running inference...")
        model = self._relaxed_model
        if haar:
            model = haar.reparam(model)
        if guide_rank == 0:
            guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
        elif guide_rank == "full":
            guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                           init_scale=init_scale)
        elif guide_rank is None or isinstance(guide_rank, int):
            guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                                  init_scale=init_scale, rank=guide_rank)
        else:
            raise ValueError("Invalid guide_rank: {}".format(guide_rank))
        Elbo = JitTrace_ELBO if jit else Trace_ELBO
        elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                    num_particles=num_particles, vectorize_particles=True,
                    ignore_jit_warnings=True)
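        # learning_rate_decay is the total decay over the whole schedule, so
        # the per-step factor "lrd" is its num_steps-th root.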
        optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                             "lrd": learning_rate_decay ** (1 / num_steps)})
        svi = SVI(model, guide, optim, elbo)

        # Run inference.
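        # The ELBO loss is normalized by self.duration, i.e. reported per time
        # step, so losses are roughly comparable across series lengths.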
        start_time = default_timer()
        losses = []
        for step in range(1 + num_steps):
            loss = svi.step() / self.duration
            if step % log_every == 0:
                logger.info("step {} loss = {:0.4g}".format(step, loss))
            losses.append(loss)
        elapsed = default_timer() - start_time
        logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                    .format(elapsed, (1 + num_steps) / elapsed))

        # Draw posterior samples.
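        # Wrapping the guide in a "particles" plate draws num_samples posterior
        # samples in a single vectorized pass; replaying the model against that
        # guide trace then recovers the remaining non-observed sample sites.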
        with torch.no_grad():
            particle_plate = pyro.plate("particles", num_samples,
                                        dim=-1 - self.max_plate_nesting)
            guide_trace = poutine.trace(particle_plate(guide)).get_trace()
            model_trace = poutine.trace(
                poutine.replay(particle_plate(model), guide_trace)).get_trace()
            self.samples = {name: site["value"] for name, site in model_trace.nodes.items()
                            if site["type"] == "sample"
                            if not site["is_observed"]
                            if not site_is_subsample(site)}
            if haar:
                haar.aux_to_user(self.samples)
        assert all(v.size(0) == num_samples for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return losses
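
Example usage (a minimal sketch, not part of this function; it assumes the
SimpleSIRModel helper from pyro.contrib.epidemiology.models and uses
illustrative synthetic data):

    import torch
    from pyro.contrib.epidemiology.models import SimpleSIRModel

    # Daily counts of newly observed infections (illustrative).
    new_cases = torch.tensor([1., 2., 4., 7., 11., 18., 29., 47.])

    model = SimpleSIRModel(population=1000, recovery_time=7.0, data=new_cases)
    losses = model.fit_svi(num_steps=1000, learning_rate=0.05, jit=True)

    # model.samples now holds posterior samples; plot `losses` to check that
    # SVI converged, then forecast one week ahead.
    samples = model.predict(forecast=7)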