def _compute_model_factors()

in pyro/infer/traceenum_elbo.py [0:0]


def _compute_model_factors(model_trace, guide_trace):
    """
    Partition the model's sample-site log-probabilities into plain cost terms
    and factors that involve dims enumerated in the model.

    Side effect: for a cost site whose ``log_prob`` touches an enumerated dim,
    the masked log-prob is computed once via ``packed.scale_and_mask`` and
    cached in ``site["packed"]["masked_log_prob"]``.

    :param model_trace: trace whose ``nodes`` mapping holds the model sites.
    :param guide_trace: trace whose ``nodes`` mapping holds the guide sites.
    :returns: a tuple ``(marginal_costs, log_factors, ordering, enum_dims, scale)``

        - ``marginal_costs``: ``OrderedDict`` from ordinal to a list of
          ``site["packed"]["log_prob"]`` for cost sites sharing no dim with
          ``enum_dims``.
        - ``log_factors``: ``OrderedDict`` from ordinal to a list of masked
          log-probs of cost sites that do touch ``enum_dims``, followed by the
          unscaled log-probs of enumerated sites whose ordinal is ``<=`` some
          ordinal already present in ``log_factors``.
        - ``ordering``: dict from sample-site name (from either trace) to the
          value of ``_find_ordinal``; the guide's entry wins on a shared name.
        - ``enum_dims``: dims of model-enumerated sites, minus every dim found
          in an ordinal or in the guide's log-prob of a shared site.
        - ``scale``: ``1`` when ``log_factors`` is empty, otherwise the result
          of ``_get_common_scale`` over the contributing sites' ``"scale"``.
    """
    # y depends on x iff ordering[x] <= ordering[y].
    # TODO refine this coarse dependency ordering using time.
    ordering = {}
    for trace in (model_trace, guide_trace):
        for name, site in trace.nodes.items():
            if site["type"] == "sample":
                ordering[name] = _find_ordinal(trace, site)

    # Every member of every ordinal is excluded from the enumerated dims.
    non_enum_dims = set()
    for ordinal in ordering.values():
        non_enum_dims.update(ordinal)

    # Bucket the model's sample sites by ordinal: sites that may have been
    # enumerated in the model go to enum_sites, everything else is a cost site.
    cost_sites = OrderedDict()
    enum_sites = OrderedDict()
    enum_dims = set()
    for name, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        ordinal = ordering[name]
        if name in guide_trace.nodes:
            cost_sites.setdefault(ordinal, []).append(site)
            guide_log_prob = guide_trace.nodes[name]["packed"]["log_prob"]
            non_enum_dims.update(guide_log_prob._pyro_dims)
        elif site["infer"].get("_enumerate_dim") is None:
            cost_sites.setdefault(ordinal, []).append(site)
        else:
            enum_sites.setdefault(ordinal, []).append(site)
            enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
    enum_dims.difference_update(non_enum_dims)

    marginal_costs = OrderedDict()
    log_factors = OrderedDict()
    if not enum_sites:
        # Nothing was enumerated in the model: every cost passes through as is.
        for ordinal, sites in cost_sites.items():
            marginal_costs[ordinal] = [site["packed"]["log_prob"] for site in sites]
        return marginal_costs, log_factors, ordering, enum_dims, 1
    _check_model_guide_enumeration_constraint(enum_sites, guide_trace)

    # Marginalize out all variables that have been enumerated in the model.
    scales = []
    for ordinal, sites in cost_sites.items():
        for site in sites:
            site_packed = site["packed"]
            if enum_dims.isdisjoint(site_packed["log_prob"]._pyro_dims):
                # Independent of every enumerated variable: an ordinary cost.
                marginal_costs.setdefault(ordinal, []).append(site_packed["log_prob"])
                continue
            # Depends on an enumerated variable: the mask is applied inside the
            # log expectation and the scale outside of it, so only the mask is
            # folded in here and the scale is collected separately.
            if "masked_log_prob" not in site_packed:
                site_packed["masked_log_prob"] = packed.scale_and_mask(
                    site_packed["unscaled_log_prob"], mask=site_packed["mask"])
            log_factors.setdefault(ordinal, []).append(site_packed["masked_log_prob"])
            scales.append(site["scale"])

    if not log_factors:
        return marginal_costs, log_factors, ordering, enum_dims, 1

    for ordinal, sites in enum_sites.items():
        # TODO refine this coarse dependency ordering using time and tensor shapes.
        # log_factors grows while this loop runs, so later enumerated sites are
        # also compared against ordinals added by earlier ones.
        if not any(ordinal <= other for other in log_factors):
            continue
        for site in sites:
            log_factors.setdefault(ordinal, []).append(site["packed"]["unscaled_log_prob"])
            scales.append(site["scale"])
    return marginal_costs, log_factors, ordering, enum_dims, _get_common_scale(scales)