# Experiments/PolicyManagers.py
def new_update_policies(self, i, sample_action_seq, subpolicy_loglikelihoods, subpolicy_entropy, latent_b, latent_z_indices,\
variational_z_logprobabilities, variational_b_logprobabilities, variational_z_probabilities, variational_b_probabilities, kl_divergence, \
latent_z_logprobabilities, latent_b_logprobabilities, latent_z_probabilities, latent_b_probabilities, \
learnt_subpolicy_loglikelihood, learnt_subpolicy_loglikelihoods, loglikelihood, prior_loglikelihood, latent_loglikelihood, temporal_loglikelihoods):
    # Set optimizer gradients to zero.
    self.optimizer.zero_grad()

    # Assemble prior and KL divergence losses.
    # These are output by the variational network, and we don't need the last z it predicts.
    prior_loglikelihood = prior_loglikelihood[:-1]
    kl_divergence = kl_divergence[:-1]
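    # Both are per-timestep quantities; only the entries up to the second-to-last timestep are used by
    # the losses below, where prior_loglikelihood is combined with the temporal log-likelihoods.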
    ######################################################
    ############## Update latent policy. #################
    ######################################################

    # Remember, the NLL loss function takes <log probabilities, target indices> as arguments.
    self.latent_b_loss = self.negative_log_likelihood_loss_function(latent_b_logprobabilities, latent_b.long())

    if self.args.discrete_z:
        self.latent_z_loss = self.negative_log_likelihood_loss_function(latent_z_logprobabilities, latent_z_indices.long())
    # If latent_z is continuous, the loss is just the negative log-likelihood of the latent_z's selected by the variational network.
    else:
        self.latent_z_loss = -latent_z_logprobabilities.squeeze(1)

    # Compute total latent loss as a weighted sum of latent_b_loss and latent_z_loss.
    self.total_latent_loss = (self.latent_b_loss_weight*self.latent_b_loss + self.latent_z_loss_weight*self.latent_z_loss)[:-1]
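    # The [:-1] slice drops the final timestep, matching the trimming of prior_loglikelihood and kl_divergence above.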
    #######################################################
    ############# Compute Variational Losses ##############
    #######################################################

    # Must always compute: the cross entropy between the variational network's b predictions and the sampled latent_b.
    self.variational_b_loss = self.negative_log_likelihood_loss_function(variational_b_logprobabilities[:-1], latent_b[:-1].long())

    # In the reparameterized case, the variational loss that goes to REINFORCE is just variational_b_loss.
    self.variational_loss = self.args.var_loss_weight*self.variational_b_loss
    #######################################################
    ########## Compute Variational Reinforce Loss #########
    #######################################################

    # Compute the reinforce target based on how we express the objective:
    # The original implementation, i.e. the entropic implementation, uses:
    # (1) \mathbb{E}_{x, z \sim q(z|x)} \Big[ \nabla_{\omega} \log q(z|x,\omega) \{ \log p(x||z) + \log p(z||x) - \log q(z|x) - 1 \} \Big]
    # The KL divergence implementation uses:
    # (2) \mathbb{E}_{x, z \sim q(z|x)} \Big[ \nabla_{\omega} \log q(z|x,\omega) \{ \log p(x||z) + \log p(z||x) - \log p(z) \} \Big] - \nabla_{\omega} D_{KL} \Big[ q(z|x) || p(z) \Big]
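    # Both forms are instances of the score-function (REINFORCE) estimator:
    #     \nabla_{\omega} \mathbb{E}_{z \sim q_{\omega}(z|x)} [ f(z) ] = \mathbb{E}_{z \sim q_{\omega}(z|x)} [ f(z) \nabla_{\omega} \log q_{\omega}(z|x) ],
    # where f(z) is the bracketed return term. A baseline that does not depend on the sampled z can be
    # subtracted from f(z) without biasing the gradient, which is the role of the moving-average baseline below.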
    # Compute the baseline target according to the NEW GRADIENT, i.e. Equation (2) above.
    baseline_target = (temporal_loglikelihoods - self.args.prior_weight*prior_loglikelihood).clone().detach()

    if self.baseline is None:
        self.baseline = torch.zeros_like(baseline_target.mean()).to(device).float()
    else:
        self.baseline = (self.beta_decay*self.baseline) + (1.-self.beta_decay)*baseline_target.mean()
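    # self.baseline is an exponential moving average of the mean reinforce target,
    #     baseline <- beta_decay*baseline + (1 - beta_decay)*mean(baseline_target),
    # initialized to zero on the first update. It is built from detached values, so no gradients flow
    # through it; it acts as a control variate that reduces the variance of the REINFORCE estimate.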
    self.reinforce_variational_loss = self.variational_loss*(baseline_target - self.baseline)

    # With reparameterization, the variational loss is a combination of three things:
    # the losses from the latent policy and subpolicy back into the variational network for the latent_z's,
    # the reinforce loss on the latent_b's, and the KL divergence.
    # Since we don't need to additionally compute the gradients from the latent and subpolicy networks into
    # the variational network, just set the variational loss to reinforce + KL.
    # self.total_variational_loss = (self.reinforce_variational_loss.sum() + self.args.kl_weight*kl_divergence.squeeze(1).sum()).sum()
    self.total_variational_loss = (self.reinforce_variational_loss + self.args.kl_weight*kl_divergence.squeeze(1)).mean()
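    # The KL term is differentiable with respect to the variational parameters through the reparameterized z,
    # so it enters the loss directly; only the discrete latent_b's rely on the score-function term above.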
    ######################################################
    # Set other losses: subpolicy, latent, and prior.
    ######################################################

    # Get subpolicy loss.
    self.subpolicy_loss = (-learnt_subpolicy_loglikelihood).mean()

    # Get prior loss.
    self.prior_loss = (-self.args.prior_weight*prior_loglikelihood).mean()

    # Reweight latent loss.
    self.total_weighted_latent_loss = (self.args.latent_loss_weight*self.total_latent_loss).mean()
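    # Each of these is a (weighted) negative log-likelihood, so minimizing the total loss below
    # maximizes the corresponding likelihoods.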
    ################################################
    # Setting total loss based on phase of training.
    ################################################

    # IF PHASE ONE:
    if self.training_phase==1:
        self.total_loss = self.subpolicy_loss + self.total_variational_loss + self.prior_loss
    # IF DONE WITH PHASE ONE:
    elif self.training_phase==2 or self.training_phase==3:
        self.total_loss = self.subpolicy_loss + self.total_weighted_latent_loss + self.total_variational_loss + self.prior_loss
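    # Phase one trains against the subpolicy, variational, and prior terms only;
    # phases two and three add the weighted latent policy loss to the objective.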
    ################################################

    if self.args.debug:
        if self.iter%self.args.debug==0:
            print("Embedding in Update Policies")
            embed()

    ################################################

    self.total_loss.sum().backward()
    self.optimizer.step()