# pyro/contrib/funsor/handlers/enum_messenger.py [0:0]
#
# Imports needed by this excerpt (taken from the top of the module);
# `_get_support_value` is a helper defined earlier in the same file.
import math
from collections import OrderedDict

import funsor
import torch

def _enum_strategy_mixture(dist, msg):
    """
    Sampling-based "mixture" enumeration strategy: draw ``num_samples``
    particles for this site, resampling any enumerated (discrete, non-plate)
    ancestor variables uniformly per particle, and weight the result as a
    uniform mixture over those particles.
    """
    sample_dim_name = "{}__PARTICLES".format(msg["name"])
    sample_inputs = OrderedDict(
        {sample_dim_name: funsor.Bint[msg["infer"]["num_samples"]]}
    )
    plate_names = frozenset(f.name for f in msg["cond_indep_stack"] if f.vectorized)
    # enumerated ancestors: discrete inputs of dist other than this site and its plates
    ancestor_names = frozenset(
        k
        for k, v in dist.inputs.items()
        if v.dtype != "real" and k != msg["name"] and k not in plate_names
    )
    plate_inputs = OrderedDict((k, dist.inputs[k]) for k in plate_names)
    # TODO should the ancestor_indices be pyro.sampled?
    ancestor_indices = {
        # TODO make this comprehension less gross
        name: _get_support_value(
            funsor.torch.distributions.CategoricalLogits(
                # sample different ancestors for each plate slice
                logits=funsor.Tensor(
                    # TODO avoid use of torch.zeros here in favor of funsor.ops.new_zeros
                    torch.zeros((1,)).expand(
                        tuple(v.dtype for v in plate_inputs.values())
                        + (dist.inputs[name].dtype,)
                    ),
                    plate_inputs,
                ),
            )(value=name).sample(name, sample_inputs),
            name,
        )
        for name in ancestor_names
    }
    # If ancestors were resampled, dist(**ancestor_indices) already carries the
    # particle dimension, so .sample() needs no extra sample_inputs.
    sampled_dist = dist(**ancestor_indices).sample(
        msg["name"], sample_inputs if not ancestor_indices else None
    )
    if ancestor_indices:  # XXX is there a better way to account for this in funsor?
        # downweight each particle by 1/num_samples so that reducing over the
        # particle dimension averages (a uniform mixture) rather than sums
        sampled_dist = sampled_dist - math.log(msg["infer"]["num_samples"])
    return sampled_dist
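

# ---------------------------------------------------------------------------
# For reference, a minimal torch-only sketch of the math implemented above.
# Everything here is a hypothetical illustration (names, shapes, and the
# helper below are not part of this module): for a site "x" whose logits
# depend on one enumerated ancestor "z" with K states, the strategy draws
# num_samples ancestor indices uniformly (the zero-logits CategoricalLogits
# above), samples "x" given each particle, and subtracts log(num_samples)
# so the particles form a uniform mixture rather than a sum.


def _mixture_strategy_sketch(x_logits, num_samples):
    # x_logits: (K, D) tensor; row k holds the logits of x given z == k
    K, _ = x_logits.shape
    # resample the ancestor uniformly, one index per particle
    z = torch.randint(K, (num_samples,))
    # sample x for each particle from its ancestor-indexed distribution
    x_dist = torch.distributions.Categorical(logits=x_logits[z])
    x = x_dist.sample()
    # per-particle log-density, downweighted by -log(num_samples) as above
    log_prob = x_dist.log_prob(x) - math.log(num_samples)
    return z, x, log_prob


# The funsor version does the same bookkeeping with named dimensions: the
# fresh particle dim "{name}__PARTICLES" enters as a Bint input via
# sample_inputs instead of a positional batch dimension, and plates stay as
# their own named inputs. Which strategy runs for a given site is decided
# elsewhere in this module, by dispatching on the site's infer dict
# (e.g. whether "num_samples" is set).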