in pyro/infer/energy_distance.py [0:0]
def _get_traces(self, model, guide, args, kwargs):
    """
    Trace the guide, then the model replayed against it, and return both traces.

    The guide is run inside a ``pyro.plate`` named
    ``"num_particles_vectorized"`` of size ``self.num_particles``. The model is
    then run under ``poutine.uncondition()`` while replaying the guide trace.
    Latent sites therefore reuse the guide's values, and observed sites are
    resampled (posterior predictive samples).

    :param model: the model callable.
    :param guide: the guide callable.
    :param args: positional arguments forwarded to both ``model`` and ``guide``.
    :param kwargs: keyword arguments forwarded to both ``model`` and ``guide``.
    :returns: a ``(guide_trace, model_trace)`` pair, both with subsample sites
        pruned. When ``self.prior_scale > 0``, log-probs have been computed for
        the model's non-observed sites.
    :raises ValueError: when validation is enabled and either a guide sample
        site or an observed model sample site has a distribution without
        ``has_rsample``.
    """
    if self.max_plate_nesting == float("inf"):
        with validation_enabled(False):  # Avoid calling .log_prob() when undefined.
            # TODO factor this out as a stand-alone helper.
            ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
    vectorize = pyro.plate("num_particles_vectorized", self.num_particles,
                           dim=-self.max_plate_nesting)

    # Trace the guide as in ELBO.
    with poutine.trace() as tr, vectorize:
        guide(*args, **kwargs)
    guide_trace = tr.trace

    # Trace the model, drawing posterior predictive samples.
    with poutine.trace() as tr, poutine.uncondition():
        with poutine.replay(trace=guide_trace), vectorize:
            model(*args, **kwargs)
    model_trace = tr.trace
    # Restore ``is_observed`` on sites whose infer dict carries "was_observed".
    # The marker is presumably added by poutine.uncondition() above; confirm.
    # Later code uses the flag to separate likelihood sites from latent sites.
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["infer"].get("was_observed", False):
            site["is_observed"] = True
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    if is_validation_enabled():
        # Every guide sample site must support reparametrized sampling.
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError("EnergyDistance requires fully reparametrized guides")
        # Every observed (likelihood) site of the model must also be
        # reparametrizable. The loop variable must be ``site``; binding another
        # name would make the body inspect a stale ``site`` from the loop above.
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["is_observed"]:
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError("EnergyDistance requires reparametrized likelihoods")

    if self.prior_scale > 0:
        # Log-probs are only computed for latent (non-observed) model sites.
        model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"])
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample" and not site["is_observed"]:
                    check_site_shape(site, self.max_plate_nesting)

    return guide_trace, model_trace