in pyro/infer/util.py [0:0]
def compute_expectation(self, costs):
    """
    Compute a differentiable expected cost by accumulating, for every
    ordinal, the contraction of marginal probabilities with cost tensors.

    :param dict costs: A dict mapping ordinals to lists of cost tensors
    :returns: a scalar expected cost
    :rtype: torch.Tensor or float
    """
    total = 0.
    # One cache is shared by every contraction performed below.
    with shared_intermediates() as cache:
        ring = MarginalRing(cache=cache)
        for ordinal, cost_terms in costs.items():
            factors = self._get_log_factors(ordinal)
            # Non-tensor log factors collapse into one scalar multiplier;
            # only the tensor-valued ones take part in the contraction.
            scale = math.exp(sum(f for f in factors if not isinstance(f, torch.Tensor)))
            tensor_factors = [f for f in factors if isinstance(f, torch.Tensor)]

            # One (initially empty) query slot per distinct set of cost dims.
            queries = dict.fromkeys(frozenset(c._pyro_dims) for c in cost_terms)
            # Fill each slot with the first log factor having matching dims.
            for factor in tensor_factors:
                dims_key = frozenset(factor._pyro_dims)
                if dims_key in queries and queries[dims_key] is None:
                    queries[dims_key] = factor
            # Slots still empty get a zero-valued factor shaped like the
            # cost, so that every cost term has something to query.
            for cost in cost_terms:
                dims_key = frozenset(cost._pyro_dims)
                if queries[dims_key] is not None:
                    continue
                placeholder = torch.zeros_like(cost)
                placeholder._pyro_dims = cost._pyro_dims
                tensor_factors.append(placeholder)
                queries[dims_key] = placeholder

            # Perform sum-product contraction over every dim that is not part
            # of the ordinal. Note that plates never need to be
            # product-contracted due to our plate-based dependency ordering.
            all_dims = set()
            for factor in tensor_factors:
                all_dims.update(factor._pyro_dims)
            sum_dims = all_dims - ordinal
            # Queries must be registered before the forward contraction so
            # that the backward pass records a result on each of them.
            for query in queries.values():
                require_backward(query)
            root = ring.sumproduct(tensor_factors, sum_dims)
            root._pyro_backward()
            probs = {}
            for dims_key, query in queries.items():
                probs[dims_key] = query._pyro_backward_result.exp()

            # Aggregate prob * cost terms.
            for cost in cost_terms:
                dims_key = frozenset(cost._pyro_dims)
                prob = probs[dims_key]
                prob._pyro_dims = queries[dims_key]._pyro_dims
                mask = prob > 0
                if not torch._C._get_tracing_state() and mask.all():
                    # Every entry has positive probability: no masking needed.
                    cost, prob = packed.broadcast_all(cost, prob)
                else:
                    # Keep only entries with positive probability (this path
                    # is always taken while tracing).
                    mask._pyro_dims = prob._pyro_dims
                    cost, prob, mask = packed.broadcast_all(cost, prob, mask)
                    prob = prob.masked_select(mask)
                    cost = cost.masked_select(mask)
                # Full contraction of prob against cost over all of its dims.
                contribution = scale * torch.tensordot(prob, cost, prob.dim())
                total = total + contribution
    LAST_CACHE_SIZE[0] = count_cached_ops(cache)
    return total