in pyro/infer/tracegraph_elbo.py [0:0]
def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
    """
    Accumulate the downstream cost of every sample site in the guide.

    Costs are built for all guide sample sites, even though ultimately only
    the non-reparameterizable ones are needed. They serve two purposes:

    1. the downstream costs are used for Rao-Blackwellization;
    2. model observe sites, as well as terms that arise from the model and
       guide having different dependency structures, are folded in through
       the model successors gathered in the second phase below.

    For each guide sample site the accumulated terms are
    ``model log_prob - guide log_prob`` of the site itself and of every
    guide sample site downstream of it, with each site intended to
    contribute exactly once. For each site in ``non_reparam_nodes`` the
    ``log_prob`` of any model successor not yet accounted for is added, and
    the result is then reduced with ``sum_to`` against the guide site's
    ``"cond_indep_stack"``.

    :param model_trace: trace of the model; must provide ``nodes`` and
        ``successors`` and be accepted by ``get_plate_stacks``.
    :param guide_trace: trace of the guide; must provide ``nodes``,
        ``successors`` and ``topological_sort``.
    :param non_reparam_nodes: iterable of guide sample site names whose
        costs are completed with model terms and reduced via ``sum_to``
        (presumably the non-reparameterizable sites -- confirm with caller).
        It is iterated twice.
    :returns: a pair ``(downstream_costs, downstream_guide_cost_nodes)``.
        ``downstream_costs`` maps each guide sample site name to its cost:
        the result of ``sum_to`` for sites in ``non_reparam_nodes`` and the
        raw ``MultiFrameTensor`` for every other site.
        ``downstream_guide_cost_nodes`` maps each guide sample site name to
        the set of site names whose terms went into that cost.
    """
    # Successors are looked up in the dicts below, so they must already have
    # been processed; the reversed topological order is relied on for that.
    guide_sample_sites = [
        name for name in guide_trace.topological_sort(reverse=True)
        if guide_trace.nodes[name]["type"] == "sample"
    ]
    topo_rank = {name: rank for rank, name in enumerate(guide_sample_sites)}

    downstream_guide_cost_nodes = {}
    downstream_costs = {}
    stacks = get_plate_stacks(model_trace)

    def cost_term(name):
        # (stack, model log_prob - guide log_prob) contribution of one site
        return (stacks[name],
                model_trace.nodes[name]['log_prob'] -
                guide_trace.nodes[name]['log_prob'])

    for site in guide_sample_sites:
        total = MultiFrameTensor(cost_term(site))
        downstream_costs[site] = total
        # sites whose terms are already part of `total`
        summed = {site}
        # every site downstream of `site` in the guide (including itself)
        reachable = {site}
        downstream_guide_cost_nodes[site] = reachable

        # make more efficient by ordering children appropriately
        # (higher children first)
        ordered_children = sorted(guide_trace.successors(site),
                                  key=lambda child: -topo_rank[child])
        for child in ordered_children:
            child_reachable = downstream_guide_cost_nodes[child]
            reachable.update(child_reachable)
            if not summed.isdisjoint(child_reachable):
                # merging this child's total would count some site twice;
                # its unseen sites are picked up one by one further down
                continue
            total.add(*downstream_costs[child].items())
            # XXX the `summed` logic could be more fine-grained, possibly
            # leading to speed-ups in case there are many duplicates
            summed.update(child_reachable)

        # include terms we missed because we had to avoid duplicates
        for skipped in reachable - summed:
            total.add(cost_term(skipped))

    # finish assembling complete downstream costs
    # (the above computation may be missing terms from model)
    for site in non_reparam_nodes:
        accounted = downstream_guide_cost_nodes[site]
        model_children = set()
        for name in accounted:
            model_children.update(model_trace.successors(name))
        # remove terms accounted for above
        model_children -= accounted
        for child in model_children:
            assert model_trace.nodes[child]["type"] == "sample"
            downstream_costs[site].add(
                (stacks[child], model_trace.nodes[child]['log_prob']))
            accounted.add(child)

    for site in non_reparam_nodes:
        downstream_costs[site] = downstream_costs[site].sum_to(
            guide_trace.nodes[site]["cond_indep_stack"])

    return downstream_costs, downstream_guide_cost_nodes