in pyro/infer/rws.py [0:0]
def _loss(self, model, guide, args, kwargs):
    """
    Build the reweighted wake-sleep losses for the model and for the guide.

    The model loss is the wake-theta estimator. The guide loss is chosen by
    ``self.insomnia``: the sleep-phi estimator when it equals 0, the wake-phi
    estimator when it equals 1, and otherwise the mixture
    ``insomnia * wake-phi + (1 - insomnia) * sleep-phi``.

    This method only constructs and returns the two losses; it contains no
    direct ``backward()`` call.

    :param model: the model, traced with ``args``/``kwargs``
    :param guide: the guide, traced with ``args``/``kwargs``
    :param args: positional arguments forwarded to the model and guide
    :param kwargs: keyword arguments forwarded to the model and guide
    :returns: ``(wake_theta_loss, phi_loss)``; ``wake_theta_loss`` is a
        ``torch.Tensor`` (a constant placeholder of 100 when the wake terms
        are skipped)
    :rtype: tuple
    """

    def accumulate_sample_log_probs(trace):
        # Sum the log-probabilities of all sample sites of ``trace``,
        # visiting sites in the trace's own node order.
        total = 0.0
        for site in trace.nodes.values():
            if site["type"] != "sample":
                continue
            if self.vectorize_particles:
                # Collapse everything except the leading particle axis.
                site_term = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
            else:
                site_term = site["log_prob_sum"]
            total = total + site_term
        return total

    # Placeholder that is returned unchanged when the wake terms are skipped.
    wake_theta_loss = torch.tensor(100.0)

    if self.model_has_params or self.insomnia > 0.0:
        # Quantities shared by the wake-theta and wake-phi estimators.
        joint_terms = []
        proposal_terms = []
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            joint_terms.append(accumulate_sample_log_probs(model_trace))
            proposal_terms.append(accumulate_sample_log_probs(guide_trace))

        if self.vectorize_particles:
            log_joints = joint_terms[0]
            log_qs = proposal_terms[0]
        else:
            log_joints = torch.stack(joint_terms)
            log_qs = torch.stack(proposal_terms)

        # The guide term is detached, so the weights carry no gradient
        # through log q.
        log_w = log_joints - log_qs.detach()

        # Wake-theta loss.
        log_w_total = torch.logsumexp(log_w, dim=0)
        wake_theta_loss = -(log_w_total - math.log(self.num_particles)).sum()
        warn_if_nan(wake_theta_loss, "wake theta loss")

    if self.insomnia > 0:
        # Wake-phi loss: self-normalised weights are treated as constants.
        weights = (log_w - log_w_total).exp().detach()
        wake_phi_loss = -(weights * log_qs).sum()
        warn_if_nan(wake_phi_loss, "wake phi loss")

    if self.insomnia < 1:
        # Sleep-phi loss: trace the unconditioned model, then score the
        # detached model trace under the guide.
        sleep_model = pyro.poutine.uncondition(model)
        sleep_guide = guide
        sleep_log_q = 0.0

        if self.vectorize_particles:
            if self.max_plate_nesting == math.inf:
                self._guess_max_plate_nesting(sleep_model, sleep_guide, args, kwargs)
            sleep_model = self._vectorized_num_sleep_particles(sleep_model)
            sleep_guide = self._vectorized_num_sleep_particles(guide)

        # A single pass suffices when particles are vectorized.
        num_passes = 1 if self.vectorize_particles else self.num_sleep_particles
        for _ in range(num_passes):
            dreamed_trace = poutine.trace(sleep_model).get_trace(*args, **kwargs)
            dreamed_trace.detach_()
            matched_trace = self._get_matched_trace(
                dreamed_trace, sleep_guide, args, kwargs
            )
            sleep_log_q += matched_trace.log_prob_sum()

        sleep_phi_loss = -sleep_log_q / self.num_sleep_particles
        warn_if_nan(sleep_phi_loss, "sleep phi loss")

    # Combine the guide losses according to ``insomnia``.
    if self.insomnia == 0:
        phi_loss = sleep_phi_loss
    elif self.insomnia == 1:
        phi_loss = wake_phi_loss
    else:
        phi_loss = (
            self.insomnia * wake_phi_loss
            + (1.0 - self.insomnia) * sleep_phi_loss
        )

    return wake_theta_loss, phi_loss