in pyro/contrib/funsor/infer/traceenum_elbo.py [0:0]
def terms_from_trace(tr):
    """Collect the ELBO building blocks from an execution trace.

    Sample sites of ``tr`` are visited in trace order. Non-sample sites and
    sites whose distribution type is named ``_Subsample`` are skipped. From the
    remaining sites this gathers:

    - ``"log_factors"``: the site ``log_prob`` funsors, kept for observed sites
      and for any site not flagged ``replay_skipped``;
    - ``"log_measures"``: the ``log_measure`` funsor of every site that has one;
    - ``"scale"``: a single common (subsampling) scale; ``to_funsor(1.)`` unless
      a replay-active site whose ``log_prob`` depends on a measure variable has
      a scale different from 1;
    - ``"plate_vars"``: names of the vectorized plates seen so far (product
      variables);
    - ``"measure_vars"``: for sites with a ``log_measure``, the site name plus
      the inputs of its ``value`` funsor, minus the plate variables collected
      up to that site (sum variables).

    .. note:: The trace is mutated. Every site that does not supply the common
       scale has ``node["funsor"]["log_prob"]`` replaced in place by
       ``log_prob * scale``.

    :param tr: execution trace whose sample nodes carry a ``"funsor"`` dict
    :returns: dict with keys ``"log_factors"``, ``"log_measures"``, ``"scale"``,
        ``"plate_vars"`` and ``"measure_vars"``
    """
    log_factors = []
    log_measures = []
    common_scale = to_funsor(1.)
    plate_vars = frozenset()
    measure_vars = frozenset()

    for name, node in tr.nodes.items():
        # guard clause: ignore anything that is not a (non-subsample) sample site
        is_sample = node["type"] == "sample"
        if not is_sample or type(node["fn"]).__name__ == "_Subsample":
            continue

        # vectorized frames on the cond_indep_stack are product (plate) variables
        plate_vars |= frozenset(frame.name for frame in node["cond_indep_stack"] if frame.vectorized)

        site = node["funsor"]

        # a log-measure is present only at sites that are not replayed or observed;
        # the fresh non-plate variables of such a site are sum (measure) variables
        log_measure = site.get("log_measure", None)
        if log_measure is not None:
            log_measures.append(log_measure)
            measure_vars |= (frozenset(site["value"].inputs) | {name}) - plate_vars

        # A replay-active (model) site that depends on an enumerated variable and
        # has a non-unit scale supplies the common scale and is left unscaled;
        # any other site gets its log_prob scaled right here, in place.
        # The `and` chain must short-circuit: each test is only valid if the
        # previous one held.
        depends_on_enum_with_scale = (
            node.get("replay_active", False)
            and set(site["log_prob"].inputs) & measure_vars
            and float(to_data(site["scale"])) != 1.
        )
        if depends_on_enum_with_scale:
            common_scale = site["scale"]
        else:
            site["log_prob"] = site["log_prob"] * site["scale"]

        # keep the (possibly rescaled) log-density unless this latent site was skipped by replay
        if node["is_observed"] or not node.get("replay_skipped", False):
            log_factors.append(site["log_prob"])

    return {
        "log_factors": log_factors,
        "log_measures": log_measures,
        "scale": common_scale,
        "plate_vars": plate_vars,
        "measure_vars": measure_vars,
    }