def _enum_strategy_mixture()

in pyro/contrib/funsor/handlers/enum_messenger.py [0:0]


def _mixture_ancestor_index(name, size, plate_inputs, sample_inputs):
    """
    Draw a uniformly random index for the discrete input ``name``.

    One independent index is drawn for every particle (``sample_inputs``) and
    every plate slice (``plate_inputs``).

    :param str name: name of a discrete input of the site's distribution.
    :param int size: number of values ``name`` ranges over (its ``Bint`` dtype).
    :param OrderedDict plate_inputs: plate inputs to batch the draw over.
    :param OrderedDict sample_inputs: the particle input to batch the draw over.
    :returns: the result of ``_get_support_value`` on the sampled funsor
        (presumably a funsor holding the sampled indices -- confirm).
    """
    # All-zero logits give a uniform categorical over ``size`` values; batching
    # the logits over the plate inputs samples different ancestors for each
    # plate slice.
    # TODO avoid use of torch.zeros here in favor of funsor.ops.new_zeros
    batch_shape = tuple(v.dtype for v in plate_inputs.values())
    logits = funsor.Tensor(torch.zeros((1,)).expand(batch_shape + (size,)), plate_inputs)
    uniform = funsor.torch.distributions.CategoricalLogits(logits=logits)(value=name)
    return _get_support_value(uniform.sample(name, sample_inputs), name)


def _enum_strategy_mixture(dist, msg):
    """
    Replace exact enumeration of a sample site with a particle approximation.

    Draws ``msg["infer"]["num_samples"]`` particles along a new input named
    ``"<site name>__PARTICLES"``. Every discrete (non-``'real'``) input of
    ``dist`` other than the site itself and its vectorized plates is treated as
    an ancestor. For each ancestor, a uniformly random index is drawn per
    particle and per plate slice, and the indices are substituted into ``dist``.
    The site is then sampled once per particle.

    :param funsor.Funsor dist: funsor for the site's distribution.
    :param dict msg: the Pyro message for the sample site; reads ``"name"``,
        ``"infer"["num_samples"]`` and ``"cond_indep_stack"``.
    :returns: the funsor produced by ``Funsor.sample`` at ``msg["name"]``,
        offset by ``-log(num_samples)`` when ancestors were resampled.
    """
    site_name = msg["name"]
    num_samples = msg["infer"]["num_samples"]
    sample_dim_name = "{}__PARTICLES".format(site_name)
    sample_inputs = OrderedDict({sample_dim_name: funsor.Bint[num_samples]})
    plate_names = frozenset(f.name for f in msg["cond_indep_stack"] if f.vectorized)
    # Iterate in dist.inputs order rather than over the frozenset. A frozenset
    # of strings iterates in a PYTHONHASHSEED-dependent order, which would
    # change the layout of the plate dims and the order in which ancestors
    # consume the RNG. Results would then be irreproducible across processes,
    # even under a fixed seed. Plates absent from dist.inputs are skipped
    # instead of raising KeyError.
    plate_inputs = OrderedDict((k, v) for k, v in dist.inputs.items() if k in plate_names)
    ancestor_names = [k for k, v in dist.inputs.items()
                      if v.dtype != 'real' and k != site_name and k not in plate_names]
    # TODO should the ancestor_indices be pyro.sampled?
    ancestor_indices = {
        k: _mixture_ancestor_index(k, dist.inputs[k].dtype, plate_inputs, sample_inputs)
        for k in ancestor_names
    }
    # When ancestors were drawn, their indices already carry the particle input,
    # so the site is sampled once per particle without extra sample_inputs.
    sampled_dist = dist(**ancestor_indices).sample(
        site_name, sample_inputs if not ancestor_indices else None)
    if ancestor_indices:
        # XXX is there a better way to account for this in funsor?
        # Applies the 1/num_samples particle weight by hand. This presumably
        # mirrors what funsor does itself when sample_inputs is passed -- confirm.
        sampled_dist = sampled_dist - math.log(num_samples)
    return sampled_dist