def get_loss()

in tutorial/deprecated/tutorial_a2c_with_infinite_env/a2c.py


    def get_loss(self, trajectories):
        # First, we get the reward obtained at each step of the trajectories
        # The reward arrives at t+1 (since it is obtained after the action), so we use the '_reward' field of the trajectory
        # The 'reward' field corresponds to the reward at time t
        reward = trajectories["_reward"]

        # We get the mask that tells which transitions belong to a trajectory (1) and which are padding (0)
        mask = trajectories.mask()
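        # mask has shape B x T (B trajectories, T = length of the longest trajectory)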

        # We zero out the reward values that are not part of the trajectories
        reward = reward * mask
        max_length = trajectories.lengths.max().item()
        # Now, we compute the action probabilities over the trajectories so that we can call 'backward' on the resulting loss
        action_probabilities = []
        for t in range(max_length):
            proba = self.learning_model(trajectories["frame"][:, t])
            action_probabilities.append(
                proba.unsqueeze(1)
            )  # We append the probabilities and introduce the temporal dimension (second dimension)
        action_probabilities = torch.cat(
            action_probabilities, dim=1
        )  # Now, we have a B x T x n_actions tensor

        # We compute the critic value for t=0 to T (i.e. including an extra slot for the very last observation)
        critic = []
        for t in range(max_length):
            b = self.critic_model(trajectories["frame"][:, t])
            critic.append(b.unsqueeze(1))
        critic = torch.cat(critic + [b.unsqueeze(1)], dim=1).squeeze(
            -1
        )  # Now, we have a B x (T+1) tensor; the last slot is a placeholder that is overwritten below
        # We also need the critic value for the last observation of each trajectory (to compute the TD)
        # It is the '_frame' at position length-1: either the observation following the last transition (if the episode is not finished), or the last frame of the episode
        idx = torch.arange(trajectories.n_elems())
        last_critic = self.critic_model(
            trajectories["_frame"][idx, trajectories.lengths - 1]
        ).squeeze(-1)
        critic[idx, trajectories.lengths] = last_critic
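        # critic[:, t] now estimates V(s_t); for each trajectory, the entry at index 'length' is the value
        # of its last observation, and any entries beyond that are masked out in the losses below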

        # We compute the temporal difference
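        # target_t = r_{t+1} + gamma * (1 - done_{t+1}) * V(s_{t+1})   and   td_t = V(s_t) - target_t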
        target = (
            reward
            + self.config["discount_factor"]
            * (1 - trajectories["_done"].float())
            * critic[:, 1:].detach()
        )
        td = critic[:, :-1] - target

        critic_loss = td ** 2
        # We average the loss over the valid time steps of each trajectory (using the mask)
        critic_loss = (critic_loss * mask).sum(1) / mask.sum(1)
        # We average the loss over all the trajectories
        avg_critic_loss = critic_loss.mean()

        # We do the same for the REINFORCE-like policy loss
        action_distribution = torch.distributions.Categorical(action_probabilities)
        log_proba = action_distribution.log_prob(trajectories["action"])
        a2c_loss = -log_proba * td.detach()
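        # Since td = V(s_t) - target = -advantage, this term equals log_proba * advantage, i.e. a quantity to maximize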
        a2c_loss = (a2c_loss * mask).sum(1) / mask.sum(1)
        avg_a2c_loss = a2c_loss.mean()

        # We compute the entropy loss
        entropy = action_distribution.entropy()
        entropy = (entropy * mask).sum(1) / mask.sum(1)
        avg_entropy = entropy.mean()

        return DictTensor(
            {
                "critic_loss": avg_critic_loss,
                "a2c_loss": avg_a2c_loss,
                "entropy_loss": avg_entropy,
            }
        )
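
The function returns the three loss terms separately; combining them into a single scalar to optimize is left to the training loop. Below is a minimal sketch of how such a combination could look. The coefficient config keys ('critic_coef', 'a2c_coef', 'entropy_coef') and the 'agent', 'config' and 'optimizer' names are assumptions for illustration, not part of the code above.

    # Hypothetical training step (assumed names: agent, trajectories, config, optimizer)
    dt = agent.get_loss(trajectories)
    # critic_loss is minimized, while a2c_loss and entropy_loss are quantities to
    # maximize, hence the minus signs
    floss = (
        config["critic_coef"] * dt["critic_loss"]
        - config["a2c_coef"] * dt["a2c_loss"]
        - config["entropy_coef"] * dt["entropy_loss"]
    )
    optimizer.zero_grad()
    floss.backward()
    optimizer.step()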