in tutorial/deprecated/tutorial_recurrent_policy/a2c.py [0:0]
def get_loss(self, trajectories):
# First, we want to gather the reward received at each step of the trajectories
# The reward is at t+1 in each iteration (since it is obtained after the action), so we use the '_reward' field of the trajectory
# The 'reward' field corresponds to the reward at time t
reward = trajectories["_reward"]
# We get the mask that tells which transition is in a trajectory (1) or not (0)
mask = trajectories.mask()
# We remove the reward values that are not in the trajectories
reward = reward * mask
max_length = trajectories.lengths.max().item()
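# reward and mask are B x T tensors; max_length is T, the length of the longest trajectory in the batch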
# Now, we want to compute the action probabilities over the trajectories so that we can call 'backward' on the resulting loss
action_probabilities = []
agent_state = trajectories["agent_state"][:, 0]
for t in range(max_length):
# Since we are using an infinite environment, we have to re-initialize the agent_state whenever we reach the initial state of a new episode
agent_state = masked_tensor(
agent_state,
trajectories["agent_state"][:, t],
trajectories["initial_state"][:, t],
)
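# masked_tensor keeps the carried recurrent agent_state, except at steps where 'initial_state' is set, where the stored re-initialized state is used instead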
agent_state, proba = self.learning_model(
agent_state, trajectories["frame"][:, t]
)
action_probabilities.append(
proba.unsqueeze(1)
) # We append the probabilities and introduce the temporal dimension (2nd dimension)
action_probabilities = torch.cat(
action_probabilities, dim=1
) # Now, we have a B x T x n_actions tensor
# We compute the critic value for t=0 to T (i.e., including the very last observation)
critic = []
agent_state = trajectories["agent_state"][:, 0]
for t in range(max_length):
# Since we are using an infinite environment, we have to re-initialize the agent_state whenever we reach the initial state of a new episode
agent_state = masked_tensor(
agent_state,
trajectories["agent_state"][:, t],
trajectories["initial_state"][:, t],
)
agent_state, b = self.critic_model(agent_state, trajectories["frame"][:, t])
critic.append(b.unsqueeze(1))
critic = torch.cat(critic + [b.unsqueeze(1)], dim=1).squeeze(
-1
) # Now, we have a B x (T+1) tensor
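# The extra copy of the last 'b' is only a placeholder for the (T+1)-th column: for each trajectory, the column actually used for bootstrapping is overwritten below with the value of its last observation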
# We also need to compute the critic value for the last observation of each trajectory (to compute the TD)
# It is either the last element of the trajectory (if the episode is not finished) or the last frame of the episode
idx = torch.arange(trajectories.n_elems())
_, last_critic = self.critic_model(
trajectories["_agent_state"][idx, trajectories.lengths - 1],
trajectories["_frame"][idx, trajectories.lengths - 1],
)
last_critic = last_critic.squeeze(-1)
critic[idx, trajectories.lengths] = last_critic
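# For each trajectory, the column right after its last valid transition now holds the critic value of the observation following the last action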
# We compute the temporal difference
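# target_t = r_(t+1) + discount_factor * (1 - done_(t+1)) * V(s_(t+1)); the critic in the target is detached so it acts as a fixed regression target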
target = (
reward
+ self.config["discount_factor"]
* (1 - trajectories["_done"].float())
* critic[:, 1:].detach()
)
td = critic[:, :-1] - target
critic_loss = td ** 2
# We average the loss over the valid steps of each trajectory (using the mask)
critic_loss = (critic_loss * mask).sum(1) / mask.sum(1)
# We average the loss over all the trajectories
avg_critic_loss = critic_loss.mean()
# We do the same for the a2c (policy gradient) loss
action_distribution = torch.distributions.Categorical(action_probabilities)
log_proba = action_distribution.log_prob(trajectories["action"])
a2c_loss = -log_proba * td.detach()
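# -log_proba * td = log_proba * (target - V(s_t)), i.e., the log-probability of the chosen action weighted by the (detached) advantage, so no policy gradient flows into the critic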
a2c_loss = (a2c_loss * mask).sum(1) / mask.sum(1)
avg_a2c_loss = a2c_loss.mean()
# We compute the entropy loss
entropy = action_distribution.entropy()
entropy = (entropy * mask).sum(1) / mask.sum(1)
avg_entropy = entropy.mean()
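# The entropy of the action distribution is returned as 'entropy_loss'; it is typically weighted and added to the objective to encourage exploration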
return DictTensor(
{
"critic_loss": avg_critic_loss,
"a2c_loss": avg_a2c_loss,
"entropy_loss": avg_entropy,
}
)