in pyro/contrib/funsor/infer/traceenum_elbo.py [0:0]
def differentiable_loss(self, model, guide, *args, **kwargs):
    """
    Assemble the ELBO as a lazy funsor expression and return its negation.

    The guide is traced first and the model is then replayed against that
    guide trace, both under ``enum`` and, when ``self.num_particles > 1``,
    under a ``plate`` of that size. Model log-factors that mention model
    measure variables are contracted with ``partial_sum_product`` and
    multiplied by ``model_terms["scale"]``; the remaining model factors and
    the negated guide factors are used as cost terms unchanged. Each cost is
    integrated against the matching guide marginal and summed over plates.

    :param model: model callable; traced while replayed against the guide trace
    :param guide: guide callable; traced first
    :param args: positional arguments forwarded to both ``guide`` and ``model``
    :param kwargs: keyword arguments forwarded to both ``guide`` and ``model``
    :returns: the negated result of ``to_data`` applied to the optimized ELBO
        expression
    """
    # only wrap the traces in a particle plate when several particles are requested
    if self.num_particles > 1:
        particle_ctx = plate(size=self.num_particles)
    else:
        particle_ctx = contextlib.ExitStack()
    # a falsy max_plate_nesting results in first_available_dim=None being passed to enum
    first_enum_dim = (-self.max_plate_nesting - 1) if self.max_plate_nesting else None

    # record batched, enumerated, to_funsor-ed traces: guide first, then the replayed model
    with particle_ctx:
        with enum(first_available_dim=first_enum_dim):
            guide_tr = trace(guide).get_trace(*args, **kwargs)
            model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

    # pull out of both traces the metadata needed to write down the elbo
    guide_terms = terms_from_trace(guide_tr)
    model_terms = terms_from_trace(model_tr)

    # everything in this context is built lazily and only evaluated at the very end
    with funsor.interpreter.interpretation(funsor.terms.lazy):
        # split the model factors by whether they touch any model measure variable;
        # those that do are contracted below, the rest are used as costs directly
        contracted_factors = []
        uncontracted_factors = []
        for factor in model_terms["log_factors"]:
            touches_measure_var = model_terms["measure_vars"].intersection(factor.inputs)
            bucket = contracted_factors if touches_measure_var else uncontracted_factors
            bucket.append(factor)

        # eliminate the model measure variables from the factors that mention them
        marginals = funsor.sum_product.partial_sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            model_terms["log_measures"] + contracted_factors,
            plates=model_terms["plate_vars"],
            eliminate=model_terms["measure_vars"],
        )
        # subsampling and handlers.scale enter through a common scale factor
        contracted_costs = [model_terms["scale"] * marginal for marginal in marginals]

        # model costs are logp, guide costs are -logq
        costs = contracted_costs + uncontracted_factors
        costs.extend([-factor for factor in guide_terms["log_factors"]])

        # integrate guide variables out of every cost and sum over all plates
        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
        reducible_vars = plate_vars | guide_terms["measure_vars"]
        elbo = to_funsor(0, output=funsor.Real)
        for cost in costs:
            cost_inputs = frozenset(cost.inputs)
            # marginal logq in the guide over exactly the inputs this cost depends on
            log_prob = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                guide_terms["log_measures"],
                plates=plate_vars,
                eliminate=reducible_vars - cost_inputs,
            )
            # expected cost E_q[logp] or E_q[-logq] under that marginal
            elbo_term = funsor.Integrate(
                log_prob, cost, guide_terms["measure_vars"] & cost_inputs
            )
            elbo += elbo_term.reduce(funsor.ops.add, plate_vars & cost_inputs)

    # evaluate the elbo, with memoize sharing tensor computation where possible
    with funsor.memoize.memoize():
        optimized_elbo = funsor.optimizer.apply_optimizer(elbo)
        loss = -to_data(optimized_elbo)
        return loss