def terms_from_trace()

in pyro/contrib/funsor/infer/traceenum_elbo.py [0:0]


def terms_from_trace(tr):
    """Helper function to extract elbo components from an execution trace.

    Walks ``tr.nodes`` in order and, for every sample site that is not a
    ``_Subsample`` site, collects the pieces needed to build the ELBO.

    :param tr: an execution trace whose ``nodes`` mapping holds site dicts
        carrying a ``"funsor"`` sub-dict (``log_prob``, ``scale``, and
        optionally ``log_measure``/``value``).
    :returns: a dict with keys

        - ``"log_factors"``: list of ``log_prob`` funsors, taken from sites
          that are observed or not marked ``replay_skipped``;
        - ``"log_measures"``: list of ``log_measure`` funsors from sites
          that carry one;
        - ``"scale"``: ``to_funsor(1.)`` unless a replayed site whose
          ``log_prob`` mentions a measure variable has a non-unit scale, in
          which case the last such site's scale;
        - ``"plate_vars"``: frozenset of names of vectorized
          ``cond_indep_stack`` frames (product variables);
        - ``"measure_vars"``: frozenset of the fresh non-plate variables at
          sites carrying a ``log_measure`` (sum variables).

    NOTE: sites that do not take the common-scale branch have
    ``node["funsor"]["log_prob"]`` overwritten in place with
    ``log_prob * scale``, so calling this twice on the same trace rescales
    those sites twice.
    """
    log_factors, log_measures = [], []
    scale = to_funsor(1.)
    plate_vars, measure_vars = frozenset(), frozenset()

    for site_name, site in tr.nodes.items():
        # only non-subsample sample sites carry densities/measures
        is_sample = site["type"] == "sample"
        if not is_sample or type(site["fn"]).__name__ == "_Subsample":
            continue

        # product (plate) variables: every vectorized frame seen so far
        vectorized_frames = frozenset(frame.name for frame in site["cond_indep_stack"] if frame.vectorized)
        plate_vars = plate_vars | vectorized_frames

        site_funsor = site["funsor"]

        # per the original author, log_measure is only present at sites that
        # are neither replayed nor observed
        log_measure = site_funsor.get("log_measure", None)
        if log_measure is not None:
            log_measures.append(log_measure)
            # sum (measure) variables: this site's fresh inputs plus its own
            # name, minus anything already known to be a plate variable
            fresh_vars = frozenset(site_funsor["value"].inputs) | {site_name}
            measure_vars = measure_vars | (fresh_vars - plate_vars)

        # a replayed site that depends on an enumerated variable and has a
        # non-unit scale supplies the (assumed common) subsampling scale
        shares_common_scale = (
            site.get("replay_active", False)
            and set(site_funsor["log_prob"].inputs) & measure_vars
            and float(to_data(site_funsor["scale"])) != 1.
        )
        if shares_common_scale:
            scale = site_funsor["scale"]
        else:
            # default behavior: fold the site's own scale into its log_prob
            site_funsor["log_prob"] = site_funsor["log_prob"] * site_funsor["scale"]

        # keep the log-density unless the site was skipped during replay
        if site["is_observed"] or not site.get("replay_skipped", False):
            log_factors.append(site_funsor["log_prob"])

    return {"log_factors": log_factors, "log_measures": log_measures, "scale": scale,
            "plate_vars": plate_vars, "measure_vars": measure_vars}