in pyro/infer/renyi_elbo.py [0:0]
def loss_and_grads(self, model, guide, *args, **kwargs):
    """
    Computes the Renyi ELBO estimate as well as the surrogate objective used
    to form the gradient estimator, and performs ``backward()`` on the
    (negative) surrogate.  ``self.num_particles`` many samples are used to
    form both estimators.

    :returns: the loss, i.e. the negative of the Renyi ELBO estimate
    :rtype: float
    :raises NotImplementedError: if any guide site has a nonzero score
        function term (non-reparameterized sites are not supported).
    """
    elbo_particles = []
    surrogate_elbo_particles = []
    is_vectorized = self.vectorize_particles and self.num_particles > 1
    # Lazily-created zero tensor: once any particle yields a real tensor,
    # scalar-zero particles are promoted to tensors of the same type/shape.
    tensor_holder = None

    # grab a vectorized trace from the generator
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        sum_dims = get_dependent_plate_dims(model_trace.nodes.values())

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                log_prob_sum = torch_sum(site["log_prob"], sum_dims)
                # detach() the ELBO value itself; only the surrogate term
                # keeps the autograd graph that gradients flow through.
                elbo_particle = elbo_particle + log_prob_sum.detach()
                surrogate_elbo_particle = surrogate_elbo_particle + log_prob_sum

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                log_prob_sum = torch_sum(site["log_prob"], sum_dims)

                elbo_particle = elbo_particle - log_prob_sum.detach()

                if not is_identically_zero(entropy_term):
                    surrogate_elbo_particle = surrogate_elbo_particle - log_prob_sum

                if not is_identically_zero(score_function_term):
                    # Non-reparameterized guide sites are not supported yet;
                    # link to the issue: https://github.com/pyro-ppl/pyro/issues/1222
                    raise NotImplementedError

        if is_identically_zero(elbo_particle):
            if tensor_holder is not None:
                # promote the scalar zero to a tensor of the established type
                elbo_particle = torch.zeros_like(tensor_holder)
                surrogate_elbo_particle = torch.zeros_like(tensor_holder)
        else:  # elbo_particle is not identically zero
            if tensor_holder is None:
                tensor_holder = torch.zeros_like(elbo_particle)
                # change types of previous `elbo_particle`s
                for i in range(len(elbo_particles)):
                    elbo_particles[i] = torch.zeros_like(tensor_holder)
                    surrogate_elbo_particles[i] = torch.zeros_like(tensor_holder)

        elbo_particles.append(elbo_particle)
        surrogate_elbo_particles.append(surrogate_elbo_particle)

    if tensor_holder is None:
        # every particle was identically zero: nothing to differentiate
        return 0.

    if is_vectorized:
        # a single vectorized trace already carries the particle dimension
        elbo_particles = elbo_particles[0]
        surrogate_elbo_particles = surrogate_elbo_particles[0]
    else:
        elbo_particles = torch.stack(elbo_particles)
        surrogate_elbo_particles = torch.stack(surrogate_elbo_particles)

    # Renyi importance weights: log w_i = (1 - alpha) * elbo_i, averaged
    # in log-space over the particles for numerical stability.
    log_weights = (1. - self.alpha) * elbo_particles
    log_mean_weight = torch.logsumexp(log_weights, dim=0, keepdim=True) - math.log(self.num_particles)
    elbo = log_mean_weight.sum().item() / (1. - self.alpha)

    # collect parameters to train from model and guide
    # NOTE(review): this inspects only the last traces from the loop above;
    # presumably all particles share the same param sites — confirm.
    trainable_params = any(site["type"] == "param"
                           for trace in (model_trace, guide_trace)
                           for site in trace.nodes.values())

    if trainable_params and getattr(surrogate_elbo_particles, 'requires_grad', False):
        # self-normalized importance weighting of the surrogate particles
        normalized_weights = (log_weights - log_mean_weight).exp()
        surrogate_elbo = (normalized_weights * surrogate_elbo_particles).sum() / self.num_particles
        surrogate_loss = -surrogate_elbo
        surrogate_loss.backward()

    loss = -elbo
    warn_if_nan(loss, "loss")
    return loss