def compute_expectation()

in pyro/infer/util.py [0:0]


    def compute_expectation(self, costs):
        """
        Returns a differentiable expected cost, summing over costs at given ordinals.

        For each ordinal, the log factors from ``self._get_log_factors(ordinal)``
        are split into two groups:

        - Non-tensor (constant) factors. ``exp`` of their sum becomes a scalar
          ``scale`` applied to every cost at that ordinal.
        - Tensor factors. These are contracted with a sum-product over every
          dim not in ``ordinal``.

        For each distinct set of ``_pyro_dims`` among the cost tensors, one
        "query" tensor is chosen. It is either the first tensor log factor with
        exactly those dims, or a zero tensor shaped like the cost and appended
        as an extra factor. After the backward pass, each query's
        ``_pyro_backward_result`` is exponentiated. These results are presumably
        log-marginals over the query's dims; TODO confirm against the ring
        implementation. The resulting probabilities weight the matching cost
        terms, and ``prob * cost`` is summed into the result.

        All contractions run inside one ``shared_intermediates()`` context, so
        intermediate work is shared across every ordinal and cost term.

        :param dict costs: A dict mapping ordinals to lists of cost tensors.
            Each cost tensor must carry a ``_pyro_dims`` attribute. Ordinals are
            subtracted from a set of dims below, so they appear to be sets of
            dims (presumably plate dims; verify against callers).
        :returns: a scalar expected cost. This stays the Python float ``0.`` if
            ``costs`` contains no cost terms.
        :rtype: torch.Tensor or float

        Side effect: stores ``count_cached_ops(cache)`` in ``LAST_CACHE_SIZE[0]``
        after the shared-intermediates context exits.
        """
        # Share computation across all cost terms.
        with shared_intermediates() as cache:
            ring = MarginalRing(cache=cache)
            expected_cost = 0.
            for ordinal, cost_terms in costs.items():
                log_factors = self._get_log_factors(ordinal)
                # Constant (non-tensor) log factors collapse into one multiplier.
                # If there are none, sum(...) is 0 and scale is 1.
                scale = math.exp(sum(x for x in log_factors if not isinstance(x, torch.Tensor)))
                # This rebinding creates a fresh list. Appending zero queries
                # below therefore does not mutate what _get_log_factors returned.
                log_factors = [x for x in log_factors if isinstance(x, torch.Tensor)]

                # Collect log_prob terms to query for marginal probability.
                # Keys are the dim-sets of the cost terms. None marks "no query
                # chosen yet".
                queries = {frozenset(cost._pyro_dims): None for cost in cost_terms}
                for log_factor in log_factors:
                    key = frozenset(log_factor._pyro_dims)
                    # .get(..., False) returns False for dim-sets that no cost
                    # needs. It is None only for a needed key that is still
                    # unassigned, so the FIRST matching log factor wins.
                    if queries.get(key, False) is None:
                        queries[key] = log_factor
                # Ensure a query exists for each cost term.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    if queries[key] is None:
                        # An all-zero log factor adds 0 in log space. That is
                        # presumably a multiplicative identity, so it should not
                        # change the contraction. It only provides a tensor whose
                        # backward result can be read off for this dim-set.
                        query = torch.zeros_like(cost)
                        query._pyro_dims = cost._pyro_dims
                        log_factors.append(query)
                        queries[key] = query

                # Perform sum-product contraction. Note that plates never need to be
                # product-contracted due to our plate-based dependency ordering.
                # Sum out every dim that appears in any factor except those in
                # the current ordinal.
                sum_dims = set().union(*(x._pyro_dims for x in log_factors)) - ordinal
                # Queries must be marked before the contraction and backward
                # pass, so that _pyro_backward_result gets populated on them.
                for query in queries.values():
                    require_backward(query)
                root = ring.sumproduct(log_factors, sum_dims)
                root._pyro_backward()
                probs = {key: query._pyro_backward_result.exp() for key, query in queries.items()}

                # Aggregate prob * cost terms.
                for cost in cost_terms:
                    key = frozenset(cost._pyro_dims)
                    prob = probs[key]
                    # Tag prob with its dims so packed.broadcast_all can align it
                    # with cost.
                    prob._pyro_dims = queries[key]._pyro_dims
                    mask = prob > 0
                    # Masked path: keep only entries with positive probability.
                    # It is taken whenever some prob is not > 0, and always under
                    # JIT tracing. Tracing presumably forces it so the traced
                    # graph does not depend on the data-dependent mask.all()
                    # check. Dropping zero-prob entries presumably avoids
                    # 0 * inf = nan when a cost is infinite there.
                    # NOTE(review): confirm both rationales.
                    if torch._C._get_tracing_state() or not mask.all():
                        mask._pyro_dims = prob._pyro_dims
                        cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                        prob = prob.masked_select(mask)
                        cost = cost.masked_select(mask)
                    else:
                        cost, prob = packed.broadcast_all(cost, prob)
                    # Contracting over all prob.dim() dims gives the sum of
                    # elementwise prob * cost. This assumes broadcast_all returns
                    # same-shaped tensors; masked_select yields 1-D tensors.
                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())

        LAST_CACHE_SIZE[0] = count_cached_ops(cache)
        return expected_cost