in pyro/contrib/epidemiology/compartmental.py [0:0]
def fit_svi(self, *,
            num_samples=100,
            num_steps=2000,
            num_particles=32,
            learning_rate=0.1,
            learning_rate_decay=0.01,
            betas=(0.8, 0.99),
            haar=True,
            init_scale=0.01,
            guide_rank=0,
            jit=False,
            log_every=200,
            **options):
    """
    Runs stochastic variational inference to generate posterior samples.

    This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
    attribute on completion.

    This approximate inference method is useful for quickly iterating on
    probabilistic models.

    :param int num_samples: Number of posterior samples to draw from the
        trained guide. Defaults to 100.
    :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
        Note the training loop performs ``1 + num_steps`` steps, so that
        the final step is logged when ``num_steps`` is a multiple of
        ``log_every``.
    :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
        particles per step.
    :param float learning_rate: Learning rate for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param float learning_rate_decay: Learning rate decay for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
        is the total decay over the entire schedule, not per-step decay;
        it is converted to a per-step factor of
        ``learning_rate_decay ** (1 / num_steps)``.
    :param tuple betas: Momentum parameters for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
    :param guide_rank: Rank of the auto normal guide. If zero (default)
        use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
        positive integer or None, use an
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        If the string "full", use an
        :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
        latter two require more ``num_steps`` to fit.
    :type guide_rank: int, None, or str
    :param float init_scale: Initial scale of the guide. This is passed to
        whichever autoguide ``guide_rank`` selects.
    :param bool jit: Whether to use a jit compiled ELBO.
    :param int log_every: How often to log svi losses. If 0 or None,
        periodic loss logging is disabled.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024. More generally, any keyword
        option whose name starts with ``heuristic_`` is forwarded to the
        heuristic initializer with that leading prefix removed. Any other
        extra keyword option is rejected.
    :returns: Time series of SVI losses (useful to diagnose convergence).
        Each loss is divided by ``self.duration``.
    :rtype: list
    :raises ValueError: If ``guide_rank`` is not 0, "full", None, or an int.
    """
    # Save configuration for .predict().
    self.relaxed = True
    self.num_quant_bins = 1

    # Setup Haar wavelet transform. Note ``haar`` is rebound from a bool to
    # the reparameterizer object; it stays falsy when disabled.
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, _) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        haar = _HaarSplitReparam(0, self.duration, dims, supports)

    # Heuristically initialize to feasible latents. Strip only the leading
    # "heuristic_" prefix; str.replace would also remove later occurrences.
    prefix = "heuristic_"
    heuristic_options = {k[len(prefix):]: options.pop(k)
                         for k in list(options)
                         if k.startswith(prefix)}
    assert not options, "unrecognized options: {}".format(", ".join(options))
    init_strategy = self._heuristic(haar, **heuristic_options)

    # Configure variational inference.
    logger.info("Running inference...")
    model = self._relaxed_model
    if haar:
        model = haar.reparam(model)
    if guide_rank == 0:
        guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
    elif guide_rank == "full":
        guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                       init_scale=init_scale)
    elif guide_rank is None or isinstance(guide_rank, int):
        guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                              init_scale=init_scale, rank=guide_rank)
    else:
        raise ValueError("Invalid guide_rank: {}".format(guide_rank))
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                num_particles=num_particles, vectorize_particles=True,
                ignore_jit_warnings=True)
    # Convert total decay over the schedule into a per-step factor. Guard
    # num_steps == 0, where the loop below still runs a single step.
    optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                         "lrd": learning_rate_decay ** (1 / max(num_steps, 1))})
    svi = SVI(model, guide, optim, elbo)

    # Run inference.
    start_time = default_timer()
    losses = []
    for step in range(1 + num_steps):
        loss = svi.step() / self.duration
        # A falsy log_every (0 or None) disables periodic logging instead of
        # raising ZeroDivisionError.
        if log_every and step % log_every == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
        losses.append(loss)
    elapsed = default_timer() - start_time
    logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                .format(elapsed, (1 + num_steps) / elapsed))

    # Draw posterior samples.
    with torch.no_grad():
        # Vectorize sampling over a new leftmost "particles" plate.
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)
        guide_trace = poutine.trace(particle_plate(guide)).get_trace()
        model_trace = poutine.trace(
            poutine.replay(particle_plate(model), guide_trace)).get_trace()
        self.samples = {name: site["value"] for name, site in model_trace.nodes.items()
                        if site["type"] == "sample"
                        if not site["is_observed"]
                        if not site_is_subsample(site)}
        if haar:
            # Map Haar auxiliary coordinates back to user-facing latents.
            haar.aux_to_user(self.samples)
    assert all(v.size(0) == num_samples for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}
    return losses