in pyro/infer/discrete.py [0:0]
def _sample_posterior_from_trace(model, enum_trace, temperature, *args, **kwargs):
    """
    Sample enumerated sites of ``model`` via forward-backward and replay the model.

    The procedure:

    1. Collect the packed, masked log-prob factors of ``enum_trace``.
    2. Contract them with ``contract_tensor_tree`` (forward pass).
    3. Run each resulting term's ``_pyro_backward`` (backward pass) to obtain
       samples for the query sites.
    4. Build a collapsed trace in which every unobserved sample site's value is
       gathered at the sampled enumeration indices and its ``cond_indep_stack``
       is restricted to the plates of its ordinal.
    5. Replay ``model`` against that trace under ``SamplePosteriorMessenger``.

    :param model: The model callable; it is invoked as ``model(*args, **kwargs)``.
    :param enum_trace: A trace of an enumerated run of ``model``. Its
        ``plate_to_symbol``, ``symbol_to_dim`` and ``nodes`` are read; sample
        nodes must carry ``"packed"`` entries. An optional ``_sharing_cache``
        attribute is reused as the contraction cache.
    :param temperature: Forwarded to ``_make_ring`` to select the semiring used
        for the contraction and backward pass.
    :returns: Whatever ``model(*args, **kwargs)`` returns during the replay.

    Side effects: caches ``node["packed"]["masked_log_prob"]`` on nodes of
    ``enum_trace``, marks query log-probs with ``require_backward``, and sets
    ``_pyro_backward_result`` on them.
    """
    plate_to_symbol = enum_trace.plate_to_symbol
    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of plate symbols for vectorized plates of
            # size > 1. Those dims are batch dims; every other dim of the
            # factor is summed out.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Skip factors that share no dim with the dims to be summed out.
            # Such sites are not queried, so they get no _pyro_backward_result
            # and keep their original value and cond_indep_stack below.
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)
    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    # (log_probs aliases cost_terms, which is extended in place.)
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)
    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    # Reuse the trace's sharing cache when present; otherwise use a fresh dict.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    # Mark every query as pending before any backward pass runs. This lets
    # the scan below detect which queries the current ordinal's terms resolved.
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        # (cost is O(#ordinals * #queries)). Each query is assigned the first
        # ordinal whose backward pass replaced its pending marker.
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal
    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                # Only plates that are non-vectorized, size 1, or part of this
                # site's ordinal are kept.
                # NOTE(review): assumes every query was resolved by some
                # backward pass. A query still marked pending would raise
                # KeyError here.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1) or plate_to_symbol[f.name] in ordinal)
                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                    # jit_iter(sample) yields one index tensor per entry of
                    # sample._pyro_sample_dims. Only dims the packed value
                    # actually has are gathered.
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)
            collapsed_trace.add_node(node["name"], **new_node)
    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)