def update_networks()

in Experiments/PolicyManagers.py [0:0]


	def update_networks(self, dictionary, source_policy_manager):
		"""Run one optimization step for the translation model and its discriminator.

		Three objectives drive the encoder/decoder update:
		  1) Single-domain reconstruction: negative source log-likelihood plus
		     KL-weighted encoder regularization (a VAE-style loss).
		  2) Z-space discriminability: the encoder is trained to fool the
		     z-discriminator (flipped-domain targets).
		  3) Cycle-consistency: log-likelihood of the original actions under the
		     source decoder evaluated on cross-domain (cycle-reconstructed) inputs.

		After the encoder/decoder step, the z-discriminator is updated separately
		on detached latents with the true-domain targets.

		Args:
			dictionary: dict of precomputed tensors; this method reads
				'source_loglikelihood', 'source_kl_divergence', 'source_latent_z',
				'source_subpolicy_inputs_original', and
				'source_subpolicy_inputs_crossdomain'.
			source_policy_manager: manager whose .forward(inputs, actions) returns
				(loglikelihood, _) for the source decoder / subpolicy.

		Side effects: sets several self.*_loss attributes and steps
		self.optimizer and self.discriminator_optimizer (each gated by
		self.skip_vae / self.skip_discriminator respectively).

		NOTE(review): `domain` and `device` are free names here (not parameters or
		locals) — presumably module-level or closure state; confirm they are in
		scope wherever this class is defined.
		"""

		####################################
		# First update encoder / decoder networks. Don't train the discriminator.
		####################################

		# Zero gradients of the encoder/decoder optimizer.
		self.optimizer.zero_grad()

		####################################
		# (1) Compute single-domain reconstruction loss.
		####################################

		# VAE loss on the current domain: negative log-likelihood plus weighted KL.
		self.source_likelihood_loss = -dictionary['source_loglikelihood'].mean()
		self.source_encoder_KL = dictionary['source_kl_divergence'].mean()
		self.source_reconstruction_loss = self.source_likelihood_loss + self.args.kl_weight*self.source_encoder_KL

		####################################
		# (2) Compute discriminability losses.
		####################################

		#	This block first computes discriminability losses:
		#	# a) Feeds latent_z into the z_discriminator, which is trained to discriminate between z's of source and target domains.
		#	# 	 Gets the log-likelihood of the discriminator predicting the true domain, and a discriminability loss used to train the _encoders_ of both domains.
		#	# b) ####### NOT YET NEEDED: ####### Could also feed the cycle-reconstructed or original source trajectory into a per-domain trajectory discriminator
		#	#	 (standard GAN adversarial training, with the full cycle-consistency translation model as the generator).
		#
		#	In addition, discriminator losses to train the discriminators themselves are computed further below,
		#	by detaching the discriminator inputs and supervising with the true domain.

		# Get z-discriminator log-probabilities.
		z_discriminator_logprob, z_discriminator_prob = self.discriminator_network(dictionary['source_latent_z'])
		# Discriminability loss: target is the FLIPPED domain (1-domain) so that gradients push the
		# encoder to fool the discriminator. This loss trains the encoders, not the discriminator.
		self.z_discriminability_loss = self.negative_log_likelihood_loss_function(z_discriminator_logprob.squeeze(1), torch.tensor(1-domain).to(device).long().view(1,))

		###### Block that computes discriminability losses assuming we are using trajectory discriminators. ######

		# # Get the right trajectory discriminator network.
		# discriminator_list = [self.source_discriminator, self.target_discriminator]
		# source_discriminator = discriminator_list[domain]

		# # Now feed the trajectory to the trajectory discriminator, based on whether it is the source or target discriminator.
		# traj_discriminator_logprob, traj_discriminator_prob = source_discriminator(trajectory)

		# # Compute trajectory discriminability loss, based on whether the trajectory was original or reconstructed.
		# self.traj_discriminability_loss = self.negative_log_likelihood_loss_function(traj_discriminator_logprob.squeeze(1), torch.tensor(1-original_or_reconstructed).to(device).long().view(1,))

		####################################
		# (3) Compute cycle-consistency losses.
		####################################

		# Evaluate the likelihood of the ORIGINAL actions under the source decoder (subpolicy),
		# with subpolicy inputs constructed from the cycle-reconstruction.

		# Slice the original action sequence out of the stacked [state | action | ...] input.
		# NOTE(review): assumes columns [state_dim : 2*state_dim] hold actions — confirm against input construction.
		original_action_sequence = dictionary['source_subpolicy_inputs_original'][:,self.state_dim:2*self.state_dim]

		# Likelihood of original actions under the source decoder given cross-domain inputs.
		cycle_reconstructed_loglikelihood, _ = source_policy_manager.forward(dictionary['source_subpolicy_inputs_crossdomain'], original_action_sequence)
		# BUGFIX: previously referenced undefined `cycle_reconstruction_loss` on the RHS (NameError).
		# Negate and weight the cycle-reconstructed log-likelihood to form the loss.
		self.cycle_reconstruction_loss = -self.args.cycle_reconstruction_loss_weight*cycle_reconstructed_loglikelihood.mean()

		####################################
		# Combine losses, compute gradients, and step the encoder/decoder optimizer.
		####################################

		self.total_VAE_loss = self.source_reconstruction_loss + self.z_discriminability_loss + self.cycle_reconstruction_loss

		# Only step when in an encoder/decoder training phase.
		if not(self.skip_vae):
			self.total_VAE_loss.backward()
			self.optimizer.step()

		####################################
		# Now compute discriminator losses and update the discriminator network(s).
		####################################

		# Zero out the discriminator gradients.
		self.discriminator_optimizer.zero_grad()

		# Detach the latent z fed to the discriminator before computing its loss.
		# Re-using the earlier forward pass would make PyTorch complain about backpropagating
		# through an already-freed graph, so pass the (detached) latent through again.
		z_discriminator_detach_logprob, z_discriminator_detach_prob = self.discriminator_network(dictionary['source_latent_z'].detach())

		# Discriminator loss: target is the TRUE domain (contrast with the flipped target above).
		self.z_discriminator_loss = self.negative_log_likelihood_loss_function(z_discriminator_detach_logprob.squeeze(1), torch.tensor(domain).to(device).long().view(1,))

		if not(self.skip_discriminator):
			# Backpropagate and step the discriminator optimizer.
			self.z_discriminator_loss.backward()
			self.discriminator_optimizer.step()