in pyro/infer/traceenum_elbo.py [0:0]
def _compute_model_factors(model_trace, guide_trace):
    """
    Split the model's sample sites into plain costs and enumeration factors.

    Sites whose name also appears in the guide, and sites without an
    ``"_enumerate_dim"`` entry in ``site["infer"]``, are cost sites; the
    remaining sample sites are treated as enumerated in the model.

    :param model_trace: trace of the model; only nodes whose ``"type"`` is
        ``"sample"`` are considered.
    :param guide_trace: trace of the guide.
    :returns: a tuple ``(marginal_costs, log_factors, ordering, enum_dims,
        scale)`` where

        - ``marginal_costs`` maps ordinal -> list of packed ``log_prob``
          terms of cost sites that share no dim with ``enum_dims``;
        - ``log_factors`` maps ordinal -> list of unscaled log-prob terms
          (masked for cost sites) that involve model-enumerated dims;
        - ``ordering`` maps site name -> ordinal from ``_find_ordinal``
          (guide entries override model entries of the same name);
        - ``enum_dims`` is the set of dims of model-enumerated sites minus
          every dim seen in an ordinal or in a guide ``log_prob``;
        - ``scale`` is ``1`` unless ``log_factors`` is non-empty, in which
          case it is ``_get_common_scale`` of the contributing site scales.

    Note that ``site["packed"]["masked_log_prob"]`` is cached on cost sites
    as a side effect.
    """
    # y depends on x iff ordering[x] <= ordering[y]
    # TODO refine this coarse dependency ordering using time.
    ordering = {}
    for trace in (model_trace, guide_trace):
        for name, site in trace.nodes.items():
            if site["type"] == "sample":
                ordering[name] = _find_ordinal(trace, site)

    # Every dim that occurs in some ordinal counts as non-enumerated.
    # NOTE(review): assumes ordinals are iterables of dims - confirm
    # against _find_ordinal.
    non_enum_dims = set()
    for ordinal in ordering.values():
        non_enum_dims.update(ordinal)

    # Collect model sites that may have been enumerated in the model.
    cost_sites = OrderedDict()
    enum_sites = OrderedDict()
    enum_dims = set()
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        ordinal = ordering[name]
        in_guide = name in guide_trace.nodes
        if in_guide or site["infer"].get("_enumerate_dim") is None:
            cost_sites.setdefault(ordinal, []).append(site)
            if in_guide:
                guide_log_prob = guide_trace.nodes[name]["packed"]["log_prob"]
                non_enum_dims.update(guide_log_prob._pyro_dims)
            continue
        enum_sites.setdefault(ordinal, []).append(site)
        enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
    enum_dims.difference_update(non_enum_dims)

    log_factors = OrderedDict()
    if not enum_sites:
        # Nothing was enumerated in the model: every cost is used as-is.
        marginal_costs = OrderedDict()
        for ordinal, sites in cost_sites.items():
            marginal_costs[ordinal] = [s["packed"]["log_prob"] for s in sites]
        return marginal_costs, log_factors, ordering, enum_dims, 1

    _check_model_guide_enumeration_constraint(enum_sites, guide_trace)

    # Marginalize out all variables that have been enumerated in the model.
    marginal_costs = OrderedDict()
    scales = []
    for ordinal, sites in cost_sites.items():
        for site in sites:
            packed_terms = site["packed"]
            log_prob = packed_terms["log_prob"]
            if enum_dims.isdisjoint(log_prob._pyro_dims):
                # Sites that do not depend on an enumerated variable are
                # handled as usual.
                marginal_costs.setdefault(ordinal, []).append(log_prob)
                continue
            # Sites that depend on an enumerated variable need the mask
            # applied inside- and the scale outside- of the log expectation.
            if "masked_log_prob" not in packed_terms:
                packed_terms["masked_log_prob"] = packed.scale_and_mask(
                    packed_terms["unscaled_log_prob"],
                    mask=packed_terms["mask"])
            log_factors.setdefault(ordinal, []).append(
                packed_terms["masked_log_prob"])
            scales.append(site["scale"])

    scale = 1
    if log_factors:
        for ordinal, sites in enum_sites.items():
            # TODO refine this coarse dependency ordering using time and tensor shapes.
            # The check sees ordinals added by earlier iterations of this
            # loop, so iteration order matters here.
            if not any(ordinal <= other for other in log_factors):
                continue
            for site in sites:
                log_factors.setdefault(ordinal, []).append(
                    site["packed"]["unscaled_log_prob"])
                scales.append(site["scale"])
        scale = _get_common_scale(scales)
    return marginal_costs, log_factors, ordering, enum_dims, scale