# Experiments/PolicyManagers.py
def update_networks(self, dictionary, source_policy_manager):
# Here are the objectives we need to consider.
# 1) Reconstruction of inputs under single-domain encoding / decoding.
# In this implementation, we just use the source_loglikelihood for this.
# 2) Discriminability of the Z space. This is taken care of by the compute_discriminator_losses function.
# 3) Cycle-consistency. This may be implemented as regression (L2), loglikelihood of the cycle-reconstructed trajectory, or discriminability of trajectories.
# In this implementation, we just use the cross-domain decoded loglikelihood.
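# For reference, a minimal sketch of the module-level names and attributes this
# method assumes exist. The exact constructors below are assumptions based on
# standard PyTorch usage, not confirmed by this excerpt:
#
#   import torch                    # assumed module-level import
#   device = torch.device("cuda")   # assumed module-level device handle
#   self.negative_log_likelihood_loss_function = torch.nn.NLLLoss()
#   self.optimizer = torch.optim.Adam(encoder_decoder_params, lr=self.args.learning_rate)
#   self.discriminator_optimizer = torch.optim.Adam(self.discriminator_network.parameters(), lr=self.args.learning_rate)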
####################################
# First update encoder decoder networks. Don't train discriminator.
####################################
# Zero gradients.
self.optimizer.zero_grad()
####################################
# (1) Compute single-domain reconstruction loss.
####################################
# Compute the VAE loss on the current domain as negative loglikelihood plus weighted KL divergence.
self.source_likelihood_loss = -dictionary['source_loglikelihood'].mean()
self.source_encoder_KL = dictionary['source_kl_divergence'].mean()
self.source_reconstruction_loss = self.source_likelihood_loss + self.args.kl_weight*self.source_encoder_KL
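# Shape sketch for the terms above (assumed shapes, not confirmed by this excerpt):
# if source_loglikelihood is a per-timestep tensor, e.g. of shape (T,) or (T, B),
# both .mean() calls reduce it to a scalar, so with e.g. kl_weight = 0.01 this is
#   loss = -loglikelihood.mean() + 0.01 * kl.mean(),
# i.e. the standard negative ELBO with a weighted KL term (beta-VAE style).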
####################################
# (2) Compute discriminability losses.
####################################
# This block first computes discriminability losses:
# # a) First, feeds the latent_z into the z_discriminator, which is being trained to discriminate between z's of the source and target domains.
# # Gets and returns the loglikelihood of the discriminator predicting the true domain.
# # Also returns the discriminability loss, which is used to train the _encoders_ of both domains.
# #
# # b) ####### DON'T NEED TO DO THIS YET: ####### Also feeds either the cycle-reconstructed trajectory or the original trajectory from the source domain into a separate discriminator.
# # This second discriminator is specific to the domain we are operating in, and discriminates between reconstructed and original trajectories.
# # This is basically standard GAN adversarial training, except the generative model here is the entire cycle-consistency translation model.
#
# In addition to this, we must also compute discriminator losses to train the discriminators themselves.
# # a) For the z discriminator (and, if we're using trajectory discriminators, those too), clone and detach the inputs of the discriminator, and compute a discriminator loss with the correct domain used as the target / supervision.
# # This discriminator loss is what is actually used to train the discriminators.
# Get the z discriminator log probabilities.
z_discriminator_logprob, z_discriminator_prob = self.discriminator_network(dictionary['source_latent_z'])
# Compute the discriminability loss. Remember, this is not used to train the discriminator, but rather the encoders.
self.z_discriminability_loss = self.negative_log_likelihood_loss_function(z_discriminator_logprob.squeeze(1), torch.tensor(1-domain).to(device).long().view(1,))
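# Label-flip sketch (assumption: `domain` is an int in {0, 1}, 0 = source and
# 1 = target, set by the enclosing class; this excerpt does not show where).
# The encoder's target is the *opposite* domain, i.e. it minimizes
#   NLLLoss(log p(domain | z), 1 - domain),
# so the encoder is rewarded when the discriminator misclassifies z. The
# discriminator itself is trained with the true label `domain` at the end of
# this method.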
###### Block that computes discriminability losses assuming we are using trajectory discriminators. ######
# # Get the right trajectory discriminator network.
# discriminator_list = [self.source_discriminator, self.target_discriminator]
# source_discriminator = discriminator_list[domain]
# # Now feed the trajectory to the trajectory discriminator, based on whether it is the source or target discriminator.
# traj_discriminator_logprob, traj_discriminator_prob = source_discriminator(trajectory)
# # Compute trajectory discriminability loss, based on whether the trajectory was original or reconstructed.
# self.traj_discriminability_loss = self.negative_log_likelihood_loss_function(traj_discriminator_logprob.squeeze(1), torch.tensor(1-original_or_reconstructed).to(device).long().view(1,))
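# Were the trajectory discriminators enabled, a hypothetical minimal module
# matching the (logprob, prob) interface above could look like the following
# (illustrative sketch only; TrajectoryDiscriminator is not defined in this file):
#
# class TrajectoryDiscriminator(torch.nn.Module):
# 	def __init__(self, input_dim, hidden_dim=64):
# 		super().__init__()
# 		# Encode the trajectory with an LSTM, then classify original vs. reconstructed.
# 		self.lstm = torch.nn.LSTM(input_dim, hidden_dim, batch_first=True)
# 		self.classifier = torch.nn.Linear(hidden_dim, 2)
#
# 	def forward(self, trajectory):
# 		# trajectory: (batch, time, input_dim).
# 		_, (hidden, _) = self.lstm(trajectory)
# 		logits = self.classifier(hidden[-1])
# 		logprob = torch.nn.functional.log_softmax(logits, dim=-1)
# 		return logprob, logprob.exp()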
####################################
# (3) Compute cycle-consistency losses.
####################################
# Must compute likelihoods of original actions under the cycle reconstructed trajectory states.
# I.e. evaluate likelihood of original actions under source_decoder (i.e. source subpolicy), with the subpolicy inputs constructed from cycle-reconstruction.
# Get the original action sequence. (The subpolicy inputs are assumed to concatenate [state; action], so actions occupy dimensions state_dim:2*state_dim.)
original_action_sequence = dictionary['source_subpolicy_inputs_original'][:,self.state_dim:2*self.state_dim]
# Now evaluate likelihood of actions under the source decoder.
cycle_reconstructed_loglikelihood, _ = source_policy_manager.forward(dictionary['source_subpolicy_inputs_crossdomain'], original_action_sequence)
# Reweight the cycle-reconstructed loglikelihood to construct the loss.
self.cycle_reconstruction_loss = -self.args.cycle_reconstruction_loss_weight*cycle_reconstructed_loglikelihood.mean()
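# Sign check: cycle_reconstructed_loglikelihood is a loglikelihood, so negating
# and weighting it gives a loss that is minimized when the source decoder assigns
# high probability to the original actions under the cycle-reconstructed states.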
####################################
# Now that individual losses are computed, compute total loss, compute gradients, and then step.
####################################
# First combine losses.
self.total_VAE_loss = self.source_reconstruction_loss + self.z_discriminability_loss + self.cycle_reconstruction_loss
# If we are in an encoder / decoder training phase, compute gradients and step.
if not(self.skip_vae):
self.total_VAE_loss.backward()
self.optimizer.step()
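# The skip_vae / skip_discriminator flags suggest alternating-phase training.
# One hypothetical schedule (not shown in this excerpt) flips them on an
# iteration counter, so each step trains either the translation model or the
# discriminator, but never both:
#   self.skip_vae = (counter % 2 == 1)
#   self.skip_discriminator = not self.skip_vae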
####################################
# Now compute discriminator losses and update discriminator network(s).
####################################
# First zero out the discriminator gradients.
self.discriminator_optimizer.zero_grad()
# Detach the latent z that is fed to the discriminator, and then compute the discriminator loss.
# If we zeroed the discriminator gradients and then used the NLL loss on it again, PyTorch would complain about backpropagating through a part of the graph that we have already gone backward through.
# Instead, just pass things through the discriminator again, but this time with latent_z detached.
z_discriminator_detach_logprob, z_discriminator_detach_prob = self.discriminator_network(dictionary['source_latent_z'].detach())
# Compute the discriminator loss used to train the discriminator itself.
self.z_discriminator_loss = self.negative_log_likelihood_loss_function(z_discriminator_detach_logprob.squeeze(1), torch.tensor(domain).to(device).long().view(1,))
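# Gradient-flow note: because source_latent_z was detached above, backward() on
# z_discriminator_loss populates gradients only for
# self.discriminator_network.parameters(); the encoders that produced the z
# receive no gradient from this loss.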
if not(self.skip_discriminator):
# Now go backward and take a step.
self.z_discriminator_loss.backward()
self.discriminator_optimizer.step()