def _quantized_model()

in pyro/contrib/epidemiology/compartmental.py


    def _quantized_model(self):
        """
        Quantized vectorized model used for parallel-scan enumerated inference.
        This method is called only outside of the particle_plate.
        """
        C = len(self.compartments)
        T = self.duration
        Q = self.num_quant_bins
        R_shape = getattr(self.population, "shape", ())  # Region shape.

        # Sample global parameters and auxiliary variables.
        params = self.global_model()
        auxiliary, non_compartmental = self._sample_auxiliary()

        # Manually enumerate.
        curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
                                        num_quant_bins=self.num_quant_bins)
        curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
        logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
        curr.update(non_compartmental)
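        # Note: each compartmental curr[name] now carries a rightmost
        # enumeration dim of size Q (hence the shifted cat dims below), and
        # logp[name] holds the matching per-bin quantization log-weights.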

        # Truncate final value from the right then pad initial value onto the left.
        init = self.initialize(params)
        prev = {}
        for name, value in init.items():
            if name in self.compartments:
                if isinstance(value, torch.Tensor):
                    value = value[..., None]  # Because curr is enumerated on the right.
                prev[name] = cat2(value, curr[name][:-1],
                                  dim=-3 if self.is_regional else -2)
            else:  # non-compartmental
                prev[name] = cat2(init[name], curr[name][:-1], dim=-curr[name].dim())

        # Reshape to support broadcasting, similar to EnumMessenger.
        def enum_reshape(tensor, position):
            assert tensor.size(-1) == Q
            assert tensor.dim() <= self.max_plate_nesting + 2
            tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
            shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2))
            shape.extend(tensor.shape[1:])
            return tensor.reshape(shape)

        # Each compartment gets its own pair of enumeration dims: curr[name]
        # at position e and prev[name] at position e + C, counting leftward
        # from the batch/plate dims.
        for e, name in enumerate(self.compartments):
            curr[name] = enum_reshape(curr[name], e)
            logp[name] = enum_reshape(logp[name], e)
            prev[name] = enum_reshape(prev[name], e + C)

        # Enable approximate inference by using aux as a non-enumerated proxy
        # for enumerated compartment values.
        for name in self.approximate:
            aux = auxiliary[self.compartments.index(name)]
            curr[name + "_approx"] = aux
            prev[name + "_approx"] = cat2(init[name], aux[:-1],
                                          dim=-2 if self.is_regional else -1)

        # Record transition factors.
        with poutine.block(), poutine.trace() as tr:
            with self.time_plate:
                t = slice(0, T, 1)  # Used to slice data tensors.
                self._transition_bwd(params, prev, curr, t)
        tr.trace.compute_log_prob()
        for name, site in tr.trace.nodes.items():
            if site["type"] == "sample":
                log_prob = site["log_prob"]
                if log_prob.dim() <= self.max_plate_nesting:  # Not enumerated.
                    pyro.factor("transition_" + name, site["log_prob_sum"])
                    continue
                if self.is_regional and log_prob.shape[-1:] != R_shape:
                    # Poor man's tensor variable elimination: spread the factor
                    # evenly across regions so it broadcasts against R_shape.
                    log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0]
                logp[name] = log_prob

        # Manually perform variable elimination.
        logp = reduce(operator.add, logp.values())
        logp = logp.reshape(Q ** C, Q ** C, T, -1)  # prev, curr, T, batch
        logp = logp.permute(3, 2, 0, 1).squeeze(0)  # batch, T, prev, curr
        logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)  # batch, prev, curr
        logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum()
        warn_if_nan(logp)
        pyro.factor("transition", logp)

        self._clear_plates()
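
For readers who have not seen the parallel-scan trick, the sketch below (plain PyTorch, toy shapes, and names invented for illustration; not part of the library) shows the log-space contraction that the final block performs: each time step contributes a matrix of factors indexed by (prev state, curr state), and the chain is reduced by repeated log-matmul-exp. Pyro's _sequential_logmatmulexp performs this reduction with a parallel scan; a plain sequential loop is used here for clarity.

    import torch

    def logmatmulexp(x, y):
        # Log-space matrix product: logsumexp over the shared middle index.
        return torch.logsumexp(x.unsqueeze(-1) + y.unsqueeze(-3), dim=-2)

    T, S = 5, 4                        # time steps, joint quantized states (S = Q ** C)
    factors = torch.randn(T, S, S)     # factors[t, prev, curr] = summed log-factor at step t

    result = factors[0]
    for t in range(1, T):              # sequential reference for the parallel scan
        result = logmatmulexp(result, factors[t])

    # Sum out the initial prev state and the final curr state, as in the
    # reshape(...).logsumexp(-1) step above.
    total_log_prob = result.logsumexp(dim=(-2, -1))
    print(total_log_prob)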