def get_loss()

in rlalgos/deprecated/ppo/discrete_ppo.py [0:0]
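
Recomputes the learning model over a batch of complete, equal-length trajectories and returns the PPO clipped-surrogate, value, and entropy loss terms as a DictTensor.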


    def get_loss(self, trajectories):
        device = self.config["learner_device"]
        trajectories = trajectories.to(device)
        max_length = trajectories.lengths.max().item()
        assert trajectories.lengths.eq(max_length).all()
        actions = trajectories["action"]
        actions_probabilities = trajectories["action_probabilities"]
        reward = trajectories["_reward"]
        frame = trajectories["frame"]
        last_action = trajectories["last_action"]
        done = trajectories["_done"].float()
        # Recompute action scores and values by replaying the model over the trajectories
        n_action_scores = []
        n_values = []
        hidden_state = trajectories["agent_state"][:, 0]
        for T in range(max_length):
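            # Where a new episode starts at step T, restart from the stored agent state
            # instead of carrying over the previous hidden state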
            hidden_state = masked_tensor(
                hidden_state,
                trajectories["agent_state"][:, T],
                trajectories["initial_state"][:, T],
            )
            _as, _v, hidden_state = self.learning_model(
                hidden_state, frame[:, T], last_action[:, T]
            )
            n_action_scores.append(_as.unsqueeze(1))
            n_values.append(_v.unsqueeze(1))
        n_action_scores = torch.cat(n_action_scores, dim=1)

        # Append a zero slot so n_values has max_length + 1 entries per trajectory;
        # the bootstrap value of the final state is written into this slot below
        n_values = torch.cat(
            [*n_values, torch.zeros(trajectories.n_elems(), 1, 1).to(device)], dim=1
        ).squeeze(-1)

        # Compute value function for last state
        _idx = torch.arange(trajectories.n_elems()).to(device)
        _hidden_state = (
            hidden_state.detach()
        )  # trajectories["_agent_state"][_idx, trajectories.lengths - 1]
        _frame = trajectories["_frame"][_idx, trajectories.lengths - 1]
        _last_action = trajectories["_last_action"][_idx, trajectories.lengths - 1]
        _, _v, _ = self.learning_model(_hidden_state, _frame, _last_action)
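        # Write the bootstrap value of the final state into the appended slot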
        n_values[_idx, trajectories.lengths] = _v.squeeze(-1)

        # Generalized Advantage Estimation over the recomputed value trajectory
        advantage = self.get_gae(
            trajectories,
            n_values,
            discount_factor=self.config["discount_factor"],
            _lambda=self.config["gae_lambda"],
        )

        # The squared GAE advantage is used directly as the critic regression loss
        value_loss = advantage ** 2
        avg_value_loss = value_loss.mean()

        n_action_probabilities = torch.softmax(n_action_scores, dim=2)
        n_action_distribution = torch.distributions.Categorical(n_action_probabilities)
        # Importance ratio between the current policy and the behaviour policy
        # that generated the stored actions
        log_a = torch.distributions.Categorical(actions_probabilities).log_prob(actions)
        log_na = n_action_distribution.log_prob(actions)
        ratios = torch.exp(log_na - log_a)
        # Clipped surrogate: the ratio is clamped to [1 - eps_clip, 1 + eps_clip]
        surr1 = ratios * advantage
        surr2 = (
            torch.clamp(
                ratios, 1 - self.config["eps_clip"], 1 + self.config["eps_clip"]
            )
            * advantage
        )

        ppo_loss = torch.min(surr1, surr2)
        avg_ppo_loss = ppo_loss.mean()

        # Entropy of the current policy distribution, used as an exploration bonus
        entropy_loss = n_action_distribution.entropy()
        avg_entropy_loss = entropy_loss.mean()

        dt = DictTensor(
            {
                "entropy_loss": avg_entropy_loss,
                "ppo_loss": avg_ppo_loss,
                "value_loss": avg_value_loss,
            }
        )
        return dt
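
For context, a minimal sketch of how the three returned terms might be combined into a single scalar for an optimizer step. The coefficient keys (`entropy_coef`, `critic_coef`) and the `self.optimizer` attribute are assumptions for illustration, not names taken from this file.

    # Hedged sketch (assumed names: entropy_coef, critic_coef, self.optimizer)
    def sketch_update(self, trajectories):
        dt = self.get_loss(trajectories)
        # ppo_loss and entropy_loss are maximized, value_loss is minimized
        floss = (
            -dt["ppo_loss"]
            - self.config["entropy_coef"] * dt["entropy_loss"]
            + self.config["critic_coef"] * dt["value_loss"]
        )
        self.optimizer.zero_grad()
        floss.backward()
        self.optimizer.step()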