in pyro/contrib/epidemiology/compartmental.py [0:0]
def fit_mcmc(self, **options):
    r"""
    Runs NUTS inference to generate posterior samples.

    This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
    :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
    attribute on completion. It also sets ``.num_quant_bins`` and
    ``.relaxed`` (for later use, e.g. by ``.predict()``).

    This uses an asymptotically exact enumeration-based model when
    ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
    when ``num_quant_bins == 1``.

    :param \*\*options: Options passed to
        :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
        pulled out and have special meaning.
    :param int num_samples: Number of posterior samples to draw via mcmc.
        Defaults to 100.
    :param int num_chains: Number of MCMC chains. Defaults to 1.
    :param int max_tree_depth: (Default 5). Max tree depth of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
    :param float target_accept_prob: Target acceptance probability of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to 0.8.
    :param bool jit_compile: Passed to the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to False.
    :param dict jit_options: Passed to the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to None.
    :param bool ignore_jit_warnings: Passed to the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to True.
    :param full_mass: Specification of mass matrix of the
        :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
        over global random variables.
    :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
        of an arrowhead matrix versus simply as a block. Defaults to False.
    :param int num_quant_bins: If greater than 1, use asymptotically exact
        inference via local enumeration over this many quantization bins.
        If equal to 1, use continuous-valued relaxed approximate inference.
        Note that computational cost is exponential in `num_quant_bins`.
        Defaults to 1 for relaxed inference.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
        Defaults to False.
    :param int haar_full_mass: Number of low frequency Haar components to
        include in the full mass matrix. This is ignored if ``haar=False``,
        if ``full_mass`` is True, or if ``full_mass`` is False or empty
        (there is then no dense block to extend). If it reaches
        ``self.duration``, a full mass matrix is used instead.
        Defaults to 10.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024. More generally, every option
        named ``heuristic_<name>`` is forwarded to the heuristic as
        ``<name>``.
    :param disable_validation: Passed to
        :class:`~pyro.infer.mcmc.api.MCMC`. Defaults to None.
    :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
    :rtype: ~pyro.infer.mcmc.api.MCMC
    """
    _require_double_precision()

    # Parse options, saving some for use in .predict().
    num_samples = options.setdefault("num_samples", 100)
    num_chains = options.setdefault("num_chains", 1)
    self.num_quant_bins = options.pop("num_quant_bins", 1)
    assert isinstance(self.num_quant_bins, int)
    assert self.num_quant_bins >= 1
    self.relaxed = self.num_quant_bins == 1

    # Setup Haar wavelet transform.
    haar = options.pop("haar", False)
    haar_full_mass = options.pop("haar_full_mass", 10)
    full_mass = options.pop("full_mass", self.full_mass)
    assert isinstance(haar, bool)
    assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
    assert isinstance(full_mass, (bool, list))
    haar_full_mass = min(haar_full_mass, self.duration)
    if not haar:
        haar_full_mass = 0
    if full_mass is True:
        haar_full_mass = 0  # No need to split.
    elif haar_full_mass >= self.duration:
        full_mass = True  # Effectively full mass.
        haar_full_mass = 0
    elif not full_mass:
        # full_mass is False or [], so there is no dense block into which
        # low-frequency Haar components could be merged; without this the
        # assertion below would fail.
        haar_full_mass = 0
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, is_regional) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
    if haar_full_mass:
        assert full_mass and isinstance(full_mass, list)
        # Copy so that self.full_mass is never mutated; the += below builds
        # a new tuple rather than extending the original in place.
        full_mass = full_mass[:]
        full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

    # Heuristically initialize to feasible latents. Strip only the leading
    # "heuristic_" prefix (str.replace would remove every occurrence).
    prefix = "heuristic_"
    heuristic_options = {k[len(prefix):]: options.pop(k)
                         for k in list(options)
                         if k.startswith(prefix)}
    init_strategy = init_to_generated(
        generate=functools.partial(self._heuristic, haar, **heuristic_options))

    # Configure a kernel.
    logger.info("Running inference...")
    base_model = self._relaxed_model if self.relaxed else self._quantized_model
    model = haar.reparam(base_model) if haar else base_model
    kernel = NUTS(model,
                  full_mass=full_mass,
                  init_strategy=init_strategy,
                  max_plate_nesting=self.max_plate_nesting,
                  jit_compile=options.pop("jit_compile", False),
                  jit_options=options.pop("jit_options", None),
                  ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                  target_accept_prob=options.pop("target_accept_prob", 0.8),
                  max_tree_depth=options.pop("max_tree_depth", 5))
    if options.pop("arrowhead_mass", False):
        kernel.mass_matrix_adapter = ArrowheadMassMatrix()

    # Run mcmc. Any options not popped above are forwarded to MCMC.
    options.setdefault("disable_validation", None)
    mcmc = MCMC(kernel, **options)
    mcmc.run()
    self.samples = mcmc.get_samples()
    if haar:
        # Map Haar auxiliary samples back to the user-facing latent sites.
        haar.aux_to_user(self.samples)

    # Unsqueeze samples to align particle dim for use in poutine.condition.
    # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
    self.samples = align_samples(self.samples, base_model,
                                 particle_dim=-1 - self.max_plate_nesting)
    assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}
    return mcmc  # E.g. so user can run mcmc.summary().