def forward()

in Experiments/PolicyNetworks.py


	def forward(self, input, epsilon, new_z_selection=True):

		# Input Format must be: Sequence_Length x 1 x Input_Size. 	
		format_input = input.view((input.shape[0], self.batch_size, self.input_size))
		hidden = None
		outputs, hidden = self.lstm(format_input)
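		# outputs: Sequence_Length x batch x hidden_size (assuming a unidirectional LSTM); hidden is the final (h_n, c_n) pair.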

		# Damping factor on the termination pre-probabilities (logits), so that the prior bias is not washed out.
		variational_b_preprobabilities = self.termination_output_layer(outputs)*self.b_probability_factor

		# Predict Gaussian means and variances. 
		if self.args.mean_nonlinearity:
			mean_outputs = self.activation_layer(self.mean_output_layer(outputs))
		else:
			mean_outputs = self.mean_output_layer(outputs)
		# A softplus activation is still needed for the variance, because it must be positive.
		variance_outputs = self.variance_factor*(self.variance_activation_layer(self.variances_output_layer(outputs))+self.variance_activation_bias) + epsilon
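		# variance_factor scales the softplus output, variance_activation_bias shifts it, and the added epsilon
		# acts as a floor that keeps the predicted variances strictly positive.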
		# This should be a SET of distributions. 
		self.dists = torch.distributions.MultivariateNormal(mean_outputs, torch.diag_embed(variance_outputs))
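		# diag_embed turns each variance vector into a diagonal covariance matrix, so self.dists is a batch of
		# independent Gaussians, one per timestep.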

		# Create variables for prior and probabilities.
		prior_values = torch.zeros_like(variational_b_preprobabilities).to(device).float()
		variational_b_probabilities = torch.zeros_like(variational_b_preprobabilities).to(device).float()
		variational_b_logprobabilities = torch.zeros_like(variational_b_preprobabilities).to(device).float()

		#######################################
		################ Set B ################
		#######################################

		# Set the first b to 1, and track the last time at which b was 1.
		sampled_b = torch.zeros(input.shape[0]).to(device).int()
		sampled_b[0] = 1
		prev_time = 0
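		# For each subsequent timestep, b_t is forced to 0 below the minimum skill length, sampled epsilon-greedily
		# between the minimum and maximum skill lengths, and forced to 1 once the maximum is reached.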

		for t in range(1,input.shape[0]):
			
			# Compute time since the last b occurred. 			
			delta_t = t-prev_time
			# Compute prior value. 
			prior_values[t] = self.get_prior_value(delta_t, max_limit=self.args.skill_length)
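			# get_prior_value (defined elsewhere in this class) presumably returns a bias on the termination logits
			# that makes termination more likely as delta_t approaches the maximum skill length.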

			# Construct probabilities.
			variational_b_probabilities[t,0,:] = self.batch_softmax_layer(variational_b_preprobabilities[t,0] + prior_values[t,0])
			variational_b_logprobabilities[t,0,:] = self.batch_logsoftmax_layer(variational_b_preprobabilities[t,0] + prior_values[t,0])
	
			# Now Implement Hard Restriction on Selection of B's. 
			if delta_t < self.min_skill_time:
				# Set B to 0. I.e. Continue. 
				# variational_b_probabilities[t,0,:] = variational_b_probabilities[t,0,:]*0
				# variational_b_probabilities[t,0,0] += 1
				
				sampled_b[t] = 0.

			elif (self.min_skill_time <= delta_t) and (delta_t < self.max_skill_time):		
				# Sample b. 			
				sampled_b[t] = self.select_epsilon_greedy_action(variational_b_probabilities[t:t+1], epsilon)

			elif self.max_skill_time <= delta_t:
				# Set B to 1. I.e. select new z. 
				sampled_b[t] = 1.

			# If b is 1, set the previous time to now. 
			if sampled_b[t]==1:
				prev_time = t				

		#######################################
		################ Set Z ################
		#######################################

		# Now set the z's. If greedy, just return the means. 
		if epsilon==0.:
			sampled_z_index = mean_outputs.squeeze(1)
		# If not greedy, then reparameterize. 
		else:
			# Whether to use reparametrization trick to retrieve the latent_z's.
			if self.args.train:
				noise = torch.randn_like(variance_outputs)

				# Instead of *sampling* the latent z from the distribution, construct it as mu + sigma * noise,
				# where sigma is the square root of the predicted variance (reparameterization trick).
				sampled_z_index = mean_outputs + torch.sqrt(variance_outputs)*noise
				# Gradients can now flow through this latent_z.

				sampled_z_index = sampled_z_index.squeeze(1)

			# If evaluating, greedily get action.
			else:
				sampled_z_index = mean_outputs.squeeze(1)
		
		# Modify z's based on whether b was 1 or 0. This part should remain the same.		
		if new_z_selection:
			
			# Set the initial b to 1 (already set above; reasserted here).
			sampled_b[0] = 1

			# Initial z is already trivially set. 
			for t in range(1,input.shape[0]):
				# If b_t==0, reuse the previous z.
				# If b_t==1, keep the newly sampled z; sampled_z_index already holds a fresh sample at every
				# timestep, so nothing needs to change in that case.
				if sampled_b[t]==0:
					sampled_z_index[t] = sampled_z_index[t-1]		

		# Also compute the log probabilities of the latent_z's sampled from this network.
		variational_z_logprobabilities = self.dists.log_prob(sampled_z_index.unsqueeze(1))
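		# unsqueeze(1) restores the batch dimension so the sample matches the Sequence_Length x 1 x output_size
		# shape over which self.dists was defined.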
		variational_z_probabilities = None

		# Set standard distribution for KL. 
		standard_distribution = torch.distributions.MultivariateNormal(torch.zeros(self.output_size).to(device), torch.eye(self.output_size).to(device))
		# Compute KL.
		kl_divergence = torch.distributions.kl_divergence(self.dists, standard_distribution)
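		# kl_divergence has one entry per timestep (and batch element): KL( q(z_t) || N(0, I) ).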

		# Prior log-likelihood of the sampled z's under the standard normal prior.
		prior_loglikelihood = standard_distribution.log_prob(sampled_z_index)

		if self.args.debug:
			print("#################################")
			print("Embedding in Variational Network.")
			embed()

		return sampled_z_index, sampled_b, variational_b_logprobabilities.squeeze(1), \
		 variational_z_logprobabilities, variational_b_probabilities.squeeze(1), variational_z_probabilities, kl_divergence, prior_loglikelihood
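
For orientation, below is a minimal sketch of how this forward pass might be invoked. The instance name variational_network, the tensor sizes, and the epsilon value are illustrative assumptions only; they are not taken from the repository.

# Hypothetical usage sketch (names and shapes are assumptions, not repository code).
import torch

seq_len, input_size = 50, 16
trajectory = torch.randn(seq_len, 1, input_size)   # Sequence_Length x 1 x Input_Size
epsilon = 0.1                                      # set to 0. for fully greedy z selection

latent_z, latent_b, b_logprobs, z_logprobs, b_probs, z_probs, kl, prior_ll = \
	variational_network.forward(trajectory, epsilon)

# latent_z : Sequence_Length x output_size, held constant within a skill segment (while b_t == 0).
# latent_b : Sequence_Length, binary; 1 marks the first timestep of each new skill.
# kl, prior_ll : per-timestep KL against N(0, I) and log-likelihood of z under that prior.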