in pyro/contrib/epidemiology/compartmental.py [0:0]
def _quantized_model(self):
    """
    Quantized vectorized model used for parallel-scan enumerated inference.
    This method is called only outside particle_plate.

    The method samples global parameters and auxiliary variables, manually
    enumerates the auxiliary variables into ``self.num_quant_bins`` bins,
    traces ``self._transition_bwd`` over the full time range, and then
    eliminates the enumerated variables by hand.

    Side effects: emits one ``pyro.factor("transition_" + name, ...)`` per
    traced sample site that is not enumerated, emits a final
    ``pyro.factor("transition", ...)`` holding the eliminated log
    probability, and calls ``self._clear_plates()`` before returning.

    :returns: ``None``.
    """
    C = len(self.compartments)
    T = self.duration
    Q = self.num_quant_bins
    R_shape = getattr(self.population, "shape", ())  # Region shape.

    # Sample global parameters and auxiliary variables.
    params = self.global_model()
    auxiliary, non_compartmental = self._sample_auxiliary()

    # Manually enumerate.
    curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
                                    num_quant_bins=Q)
    curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
    logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
    curr.update(non_compartmental)

    # Truncate final value from the right then pad initial value onto the left.
    init = self.initialize(params)
    prev = {}
    for name, value in init.items():
        if name in self.compartments:
            if isinstance(value, torch.Tensor):
                value = value[..., None]  # Because curr is enumerated on the right.
            prev[name] = cat2(value, curr[name][:-1],
                              dim=-3 if self.is_regional else -2)
        else:  # non-compartmental
            prev[name] = cat2(value, curr[name][:-1], dim=-curr[name].dim())

    # Reshape to support broadcasting, similar to EnumMessenger.
    def enum_reshape(tensor, position):
        # Move the rightmost (size Q) enumeration dim to the far left, then
        # insert singleton dims so that each enumerated variable gets its
        # own dim, selected by ``position``.
        assert tensor.size(-1) == Q
        assert tensor.dim() <= self.max_plate_nesting + 2
        tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
        shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2))
        shape.extend(tensor.shape[1:])
        return tensor.reshape(shape)

    # Positions 0..C-1 are used for curr values, C..2C-1 for prev values.
    for e, name in enumerate(self.compartments):
        curr[name] = enum_reshape(curr[name], e)
        logp[name] = enum_reshape(logp[name], e)
        prev[name] = enum_reshape(prev[name], e + C)

    # Enable approximate inference by using aux as a non-enumerated proxy
    # for enumerated compartment values.
    for name in self.approximate:
        aux = auxiliary[self.compartments.index(name)]
        curr[name + "_approx"] = aux
        prev[name + "_approx"] = cat2(init[name], aux[:-1],
                                      dim=-2 if self.is_regional else -1)

    # Record transition factors.
    with poutine.block(), poutine.trace() as tr:
        with self.time_plate:
            t = slice(0, T, 1)  # Used to slice data tensors.
            self._transition_bwd(params, prev, curr, t)
    tr.trace.compute_log_prob()
    for name, site in tr.trace.nodes.items():
        if site["type"] != "sample":
            continue
        log_prob = site["log_prob"]
        if log_prob.dim() <= self.max_plate_nesting:  # Not enumerated.
            pyro.factor("transition_" + name, site["log_prob_sum"])
            continue
        if self.is_regional and log_prob.shape[-1:] != R_shape:
            # Poor man's tensor variable elimination.
            # NOTE(review): presumably the division spreads a factor that is
            # shared across regions evenly over them, so that the final sum
            # over the batch counts it once -- confirm.
            log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0]
        # Store the local (possibly rescaled) value rather than
        # site["log_prob"]; otherwise the rescaling above is discarded.
        logp[name] = log_prob

    # Manually perform variable elimination.
    logp = reduce(operator.add, logp.values())
    logp = logp.reshape(Q ** C, Q ** C, T, -1)  # prev, curr, T, batch
    logp = logp.permute(3, 2, 0, 1).squeeze(0)  # batch, T, prev, curr
    logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)  # batch, prev, curr
    logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum()
    warn_if_nan(logp)
    pyro.factor("transition", logp)

    self._clear_plates()