in Experiments/PolicyNetworks.py [0:0]
def forward(self, input, epsilon, new_z_selection=True, var_epsilon=0.001):
    """Run the variational network over a sequence and select latent z's and b's.

    The input is passed through ``self.lstm``; the LSTM outputs feed three
    heads: termination (b) logits, Gaussian means and Gaussian variances.
    A diagonal-covariance ``MultivariateNormal`` is built from the means and
    variances and stored on ``self.dists``.

    Args:
        input: Tensor viewable as
            ``(input.shape[0], self.batch_size, self.input_size)``. The first
            axis is iterated over as the timestep axis.
        epsilon: Forwarded to ``self.select_epsilon_greedy_action`` for
            choosing b. When it equals 0 the predicted means are used as the
            latent z's instead of sampling.
        new_z_selection: When true, b at the first timestep is forced to 1
            and, for every later timestep whose b is 0, the previous
            timestep's z is carried over.
        var_epsilon: Constant added to the predicted variances.

    Returns:
        An 8-tuple ``(sampled_z_index, sampled_b,
        variational_b_logprobabilities, variational_z_logprobabilities,
        variational_b_probabilities, variational_z_probabilities,
        kl_divergence, prior_loglikelihood)``.
        ``variational_z_probabilities`` is always ``None``.

    Note:
        The ``squeeze(1)`` calls only drop the batch axis when it has size 1
        (presumably ``self.batch_size == 1`` -- confirm against callers).
    """
    # Input format must be: Sequence_Length x 1 x Input_Size.
    format_input = input.view((input.shape[0], self.batch_size, self.input_size))
    outputs, _ = self.lstm(format_input)

    # Damping factor for probabilities to prevent washing out of bias.
    variational_b_preprobabilities = (
        self.termination_output_layer(outputs) * self.b_probability_factor
    )

    # Add b continuation bias to the continuing option at every timestep.
    variational_b_preprobabilities[:, 0, 0] += self.b_exploration_bias
    variational_b_probabilities = self.batch_softmax_layer(
        variational_b_preprobabilities
    ).squeeze(1)
    variational_b_logprobabilities = self.batch_logsoftmax_layer(
        variational_b_preprobabilities
    ).squeeze(1)

    # Predict Gaussian means and variances.
    if self.args.mean_nonlinearity:
        mean_outputs = self.activation_layer(self.mean_output_layer(outputs))
    else:
        mean_outputs = self.mean_output_layer(outputs)

    # The variance head goes through an activation (presumably softplus) so
    # that it stays positive; var_epsilon is added on top of the scaled value.
    variance_outputs = self.variance_factor * (
        self.variance_activation_layer(self.variances_output_layer(outputs))
        + self.variance_activation_bias
    ) + var_epsilon

    # This is a SET of distributions: one per timestep (and batch element).
    self.dists = torch.distributions.MultivariateNormal(
        mean_outputs, torch.diag_embed(variance_outputs)
    )

    sampled_b = self.select_epsilon_greedy_action(variational_b_probabilities, epsilon)

    if epsilon == 0. or (self.args.reparam and not self.args.train):
        # Greedy choice: use the means. Clone them, because the carry-over
        # loop below writes into sampled_z_index in place and mean_outputs is
        # also the `loc` of self.dists; without the copy the distribution
        # means would be overwritten before log_prob / KL are computed.
        sampled_z_index = mean_outputs.squeeze(1).clone()
    elif self.args.reparam:
        # Reparametrization trick: construct z from the mean plus scaled
        # random noise, so gradients can pass through the latent z.
        # NOTE(review): the noise is scaled by variance_outputs, which is also
        # used as the covariance diagonal above, i.e. by the variance rather
        # than its square root -- confirm this is intended.
        noise = torch.randn_like(variance_outputs)
        sampled_z_index = (mean_outputs + variance_outputs * noise).squeeze(1)
    else:
        sampled_z_index = self.dists.sample().squeeze(1)

    if new_z_selection:
        # Set initial b to 1; the initial z is already trivially set.
        sampled_b[0] = 1

        # Sequential by construction: each step may depend on the previous z.
        for t in range(1, input.shape[0]):
            # If b_t == 0, reuse the previous z. If b_t == 1, keep the z that
            # was selected for this timestep above.
            if sampled_b[t] == 0:
                sampled_z_index[t] = sampled_z_index[t - 1]

    # Log-probabilities of the selected latent z's under this network.
    variational_z_logprobabilities = self.dists.log_prob(sampled_z_index.unsqueeze(1))
    variational_z_probabilities = None

    # Standard normal prior, created on the same device as the network outputs.
    prior_device = mean_outputs.device
    standard_distribution = torch.distributions.MultivariateNormal(
        torch.zeros(self.output_size, device=prior_device),
        torch.eye(self.output_size, device=prior_device),
    )

    # KL between the predicted distributions and the standard normal prior.
    kl_divergence = torch.distributions.kl_divergence(self.dists, standard_distribution)

    # Log-likelihood of the selected z's under the prior.
    prior_loglikelihood = standard_distribution.log_prob(sampled_z_index)

    return (
        sampled_z_index,
        sampled_b,
        variational_b_logprobabilities,
        variational_z_logprobabilities,
        variational_b_probabilities,
        variational_z_probabilities,
        kl_divergence,
        prior_loglikelihood,
    )