in Experiments/PolicyManagers.py [0:0]
def run_iteration(self, counter, i):
# Phases:
# Phase 1: Train encoder-decoder for both domains initially, so that discriminator is not fed garbage.
# Phase 2: Train encoder, decoder for each domain, and Z discriminator concurrently.
# Phase 3: Train encoder, decoder for each domain, and the individual source and target discriminators, concurrently.
# Algorithm (joint training):
# For every epoch:
# # For every datapoint:
# # 1) Select which domain to use as source (i.e. with 50% chance, select either domain).
# # 2) Get trajectory segments from desired domain.
# # 3) Transfer Steps:
# # a) Encode trajectory as latent z (domain 1).
# # b) Use domain 2 decoder to decode latent z into trajectory (domain 2).
# # c) Use domain 2 encoder to encode trajectory into latent z (domain 2).
# # d) Use domain 1 decoder to decode latent z (domain 2) into trajectory (domain 1).
# # 4) Feed cycle-reconstructed trajectory and original trajectory (both domain 1) into discriminator.
# # 5) Train discriminators to predict whether a trajectory is original or cycle-reconstructed.
# # Alternative: Note that we don't actually need trajectory-level discriminator networks; a log-likelihood cycle-reconstruction loss can be used instead. Try this first.
# # Train the z discriminator to predict which domain each latent z sample came from.
# # Train the encoder / decoder architectures with a mix of reconstruction loss and a discriminator-confusion objective.
# # Compute and apply gradient updates.
# Remember to make domain-agnostic function calls to encode, feed into the discriminator, get likelihoods, etc.
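# The log-likelihood alternative in the note above replaces the trajectory-level
# discriminators with a cycle-reconstruction term; a minimal sketch, with assumed
# names (the actual helpers may differ):
#   cycle_loglikelihood = source_decoder.log_prob(original_domain1_trajectory, cycled_latent_z)
#   cycle_reconstruction_loss = -cycle_loglikelihood.mean()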
####################################
# (0) Set up things like training phases, epsilon values, etc.
####################################
self.set_iteration(counter)
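# set_iteration presumably maps the iteration counter to the training phases described
# above; a rough sketch, assuming a hypothetical self.training_phase_size attribute:
#   if counter < self.training_phase_size:
#       self.training_phase = 1    # encoder-decoder pretraining only.
#   elif counter < 2 * self.training_phase_size:
#       self.training_phase = 2    # add the z discriminator.
#   else:
#       self.training_phase = 3    # add the per-domain trajectory discriminators.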
dictionary = {}
target_dict = {}
####################################
# (1) Select which domain to use as the source domain (this also determines the z discriminator's supervision label for this iteration).
####################################
domain, source_policy_manager, target_policy_manager = self.get_source_target_domain_managers()
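# A minimal sketch of the 50/50 source selection described in step (1) of the comments
# above; policy_managers and the helper's internals here are assumptions:
#   domain = np.random.binomial(1, 0.5)
#   source_policy_manager = self.policy_managers[domain]
#   target_policy_manager = self.policy_managers[1 - domain]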
####################################
# (2) & (3 a) Get source trajectory (segment) and encode it into a latent z. Decode using the source decoder to get the log-likelihood for the reconstruction objective.
####################################
dictionary['source_subpolicy_inputs_original'], dictionary['source_latent_z'], dictionary['source_loglikelihood'], dictionary['source_kl_divergence'] = self.encode_decode_trajectory(source_policy_manager, i)
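# encode_decode_trajectory presumably samples trajectory segment i from the source
# domain, encodes it variationally into source_latent_z (with a KL term against the
# prior), and evaluates the source decoder's reconstruction log-likelihood; a sketch
# with assumed names:
#   latent_z, kl_divergence = encoder_network(subpolicy_inputs)
#   loglikelihood = decoder_network.log_prob(subpolicy_inputs, latent_z)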
####################################
# (3 b) Cross domain decoding.
####################################
target_dict['target_trajectory_rollout'], target_dict['target_subpolicy_inputs'] = self.cross_domain_decoding(domain, target_policy_manager, dictionary['source_latent_z'])
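# cross_domain_decoding presumably rolls out the target domain's decoder policy,
# conditioned on source_latent_z, from a start state to produce a target-domain
# trajectory; a rough sketch with assumed names:
#   state = start_state
#   for t in range(rollout_length):
#       action = target_policy_manager.policy_network.get_actions(state, latent_z)
#       state = state + action    # assuming actions are interpreted as state deltas.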
####################################
# (3 c) Cross domain encoding of target_trajectory_rollout into target latent_z.
####################################
dictionary['target_subpolicy_inputs'], dictionary['target_latent_z'], dictionary['target_loglikelihood'], dictionary['target_kl_divergence'] = self.encode_decode_trajectory(target_policy_manager, i, trajectory_input=target_dict)
####################################
# (3 d) Cross domain decoding of target_latent_z into a source trajectory.
# Can use the original start state, or the reverse trick for the start state; try both.
####################################
source_trajectory_rollout, dictionary['source_subpolicy_inputs_crossdomain'] = self.cross_domain_decoding(domain, source_policy_manager, dictionary['target_latent_z'], start_state=dictionary['source_subpolicy_inputs_original'][0,:self.state_dim].detach().cpu().numpy())
####################################
# (4) Feed the source latent z to the z discriminator (the selected domain provides the supervision label).
####################################
self.compute_discriminator_losses(domain, dictionary['source_latent_z'])
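# The z discriminator loss is presumably a classification loss between the
# discriminator's domain prediction for source_latent_z and the selected domain;
# a minimal sketch with assumed names:
#   domain_logprobs = self.discriminator_network(dictionary['source_latent_z'])
#   z_discriminator_loss = torch.nn.functional.nll_loss(domain_logprobs, domain_label)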
####################################
# (5) Compute all losses, reweight, and take gradient steps.
####################################
self.update_networks(dictionary, source_policy_manager)
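# update_networks presumably reweights the terms gathered in `dictionary`
# (reconstruction log-likelihood, KL divergence, discriminability / confusion terms)
# and steps the relevant optimizers; a rough sketch of the weighting, with assumed
# weight attributes:
#   total_loss = -dictionary['source_loglikelihood'].mean() \
#       + self.kl_weight * dictionary['source_kl_divergence'].mean() \
#       + self.discriminability_weight * discriminator_confusion_loss
#   self.optimizer.zero_grad(); total_loss.backward(); self.optimizer.step()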