in Experiments/PolicyManagers.py [0:0]
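# Module-level dependencies assumed from the rest of the file (not shown in this excerpt):
#     import numpy as np; import torch; from IPython import embed
# `device` is assumed to be the torch.device the networks and tensors live on.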
def run_iteration(self, counter, i):
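    """Run one iteration on trajectory `i` at step `counter`.

    Per the comment block below: sample (z, b) from the variational network,
    evaluate the likelihoods of the latent policy and subpolicy under that
    sample, then update the policies (in training mode) or print diagnostics
    (otherwise).
    """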
    # With a learnt discrete subpolicy:
    # For all epochs:
    #     For all trajectories:
    #         Sample z from the variational network.
    #         Evaluate the likelihood of the latent policy and the subpolicy.
    #         Update the policies using these likelihoods.
    self.set_epoch(counter)
    self.iter = counter

    ############# (0) #############
    # Get the sample we're going to train on. Single sample as of now.
    sample_traj, sample_action_seq, concatenated_traj, old_concatenated_traj = self.collect_inputs(i)
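    # Assumption from usage below: both `*concatenated_traj` variables are state-action
    # concatenations of the trajectory, with the `old_` variant in the input format
    # the variational network expects.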
    if sample_traj is not None:

        ############# (1) #############
        # Sample latent variables from p(\zeta | \tau).
        latent_z_indices, latent_b, variational_b_logprobabilities, variational_z_logprobabilities, \
        variational_b_probabilities, variational_z_probabilities, kl_divergence, prior_loglikelihood = \
            self.variational_policy.forward(torch.tensor(old_concatenated_traj).to(device).float(), self.epsilon)
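        # Per the variable names: `latent_z_indices` holds one discrete option index per
        # timestep and `latent_b` the per-timestep termination indicators, alongside their
        # (log-)probabilities, the KL divergence term, and the prior log likelihood.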
        ########## (2) & (3) ##########
        # Evaluate log likelihoods of actions and options as the "return" for the variational policy.
        subpolicy_loglikelihoods, subpolicy_loglikelihood, subpolicy_entropy, \
        latent_loglikelihood, latent_b_logprobabilities, latent_z_logprobabilities, \
        latent_b_probabilities, latent_z_probabilities, latent_z_logprobability, latent_b_logprobability, \
        learnt_subpolicy_loglikelihood, learnt_subpolicy_loglikelihoods, temporal_loglikelihoods = \
            self.evaluate_loglikelihoods(sample_traj, sample_action_seq, concatenated_traj, latent_z_indices, latent_b)
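        # These likelihoods are evaluated under the current latent policy and learnt
        # subpolicy, conditioned on the sampled (z, b); they serve as the return signal
        # for the REINFORCE-style update below.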
        if self.args.train:

            # Optional debug hook: every `debug` iterations, drop into an interactive IPython shell.
            if self.args.debug and self.iter % self.args.debug == 0:
                print("Embedding in Train Function.")
                embed()
            ############# (3) #############
            # Update latent policy Pi_z with a REINFORCE-style update, using the log likelihood as the return.
            self.new_update_policies(
                i, sample_action_seq, subpolicy_loglikelihoods, subpolicy_entropy, latent_b, latent_z_indices,
                variational_z_logprobabilities, variational_b_logprobabilities, variational_z_probabilities, variational_b_probabilities, kl_divergence,
                latent_z_logprobabilities, latent_b_logprobabilities, latent_z_probabilities, latent_b_probabilities,
                learnt_subpolicy_loglikelihood, learnt_subpolicy_loglikelihoods, learnt_subpolicy_loglikelihood + latent_loglikelihood,
                prior_loglikelihood, latent_loglikelihood, temporal_loglikelihoods)
            # Update plots.
            # self.update_plots(counter, sample_map, loglikelihood)
            self.update_plots(counter, i, learnt_subpolicy_loglikelihood, latent_loglikelihood, subpolicy_entropy,
                sample_traj, latent_z_logprobability, latent_b_logprobability, kl_divergence, prior_loglikelihood)

            # print("Latent LogLikelihood: ", latent_loglikelihood)
            # print("Subpolicy LogLikelihood: ", learnt_subpolicy_loglikelihood)
            print("#########################################")
        else:
            # Ground-truth latent labels (self.dataset.Y_array) presumably only exist for
            # the toy datasets, so skip the diagnostic printing for robot / mocap data.
            if self.args.data not in ['MIME', 'Roboturk', 'OrigRoboturk', 'FullRoboturk', 'Mocap']:
                print("#############################################")
                print("Trajectory", i)
                print("Predicted Z: \n", latent_z_indices.detach().cpu().numpy())
                print("True Z     : \n", np.array(self.dataset.Y_array[i][:self.args.traj_length]))
                print("Latent B   : \n", latent_b.detach().cpu().numpy())
                # print("Variational Probs: \n", variational_z_probabilities.detach().cpu().numpy())
                # print("Latent Probs     : \n", latent_z_probabilities.detach().cpu().numpy())
                print("Latent B Probs   : \n", latent_b_probabilities.detach().cpu().numpy())
                if self.args.subpolicy_model:
                    eval_encoded_logprobs = torch.zeros((latent_z_indices.shape[0]))
                    eval_orig_encoder_logprobs = torch.zeros((latent_z_indices.shape[0]))

                    torch_concat_traj = torch.tensor(concatenated_traj).to(device).float()

                    # For each timestep's z in latent_z_indices, evaluate its likelihood under the pretrained encoder model.
                    for t in range(latent_z_indices.shape[0]):
                        eval_encoded_logprobs[t] = self.encoder_network.forward(torch_concat_traj, z_sample_to_evaluate=latent_z_indices[t])
                        _, eval_orig_encoder_logprobs[t], _, _ = self.encoder_network.forward(torch_concat_traj)

                    print("Encoder Loglikelihood:", eval_encoded_logprobs.detach().cpu().numpy())
                    print("Orig Encoder Loglikelihood:", eval_orig_encoder_logprobs.detach().cpu().numpy())
            if self.args.debug:
                embed()