def update()

in exploring_exploration/algo/ppo.py


    def update(self, rollouts):
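        """Run one PPO update over the collected rollouts.

        Normalizes the advantages, then performs `ppo_epoch` passes over the
        rollout in `num_mini_batch` mini-batches, optimizing the clipped
        surrogate objective, the (optionally clipped) value loss, and an
        entropy bonus. Returns a dict with the mean value loss, action loss,
        and entropy over all mini-batch updates.
        """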
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        value_loss_epoch = 0
        action_loss_epoch = 0
        dist_entropy_epoch = 0

        for e in range(self.ppo_epoch):
            if self.actor_critic.is_recurrent:
                data_generator = rollouts.recurrent_generator(
                    advantages, self.num_mini_batch
                )
            else:
                data_generator = rollouts.feed_forward_generator(
                    advantages, self.num_mini_batch
                )

            for sample in data_generator:
                (
                    obs_im_batch,
                    obs_sm_batch,
                    obs_lm_batch,
                    recurrent_hidden_states_batch,
                    actions_batch,
                    value_preds_batch,
                    return_batch,
                    masks_batch,
                    collisions_batch,
                    old_action_log_probs_batch,
                    adv_targ,
                    T,
                    N,
                ) = sample

                # ======================= Forward pass ========================
                encoder_inputs = [obs_im_batch]
                if self.encoder_type == "rgb+map":
                    encoder_inputs += [obs_sm_batch, obs_lm_batch]
                obs_feats = self.encoder(*encoder_inputs)
                policy_inputs = {"features": obs_feats}
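                # Previous-action inputs: shift actions_batch one step forward in
                # time (zeros at the first step), then flatten back to (T * N, 1).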
                prev_actions = torch.zeros_like(actions_batch.view(T, N, 1))
                prev_actions[1:] = actions_batch.view(T, N, 1)[:-1]
                prev_actions = prev_actions.view(T * N, 1)
                prev_collisions = collisions_batch
                if self.use_action_embedding:
                    policy_inputs["actions"] = prev_actions.long()
                if self.use_collision_embedding:
                    policy_inputs["collisions"] = prev_collisions.long()
                # Evaluate all T * N steps of the mini-batch in a single forward pass
                policy_outputs = self.actor_critic.evaluate_actions(
                    policy_inputs,
                    recurrent_hidden_states_batch,
                    masks_batch,
                    actions_batch,
                )
                values, action_log_probs, dist_entropy, _ = policy_outputs
                # ===================== Compute PPO losses ====================
                # Clipped surrogate loss
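                # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space
                # for numerical stability.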
                ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
                surr1 = ratio * adv_targ
                surr2 = (
                    torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
                    * adv_targ
                )
                action_loss = -torch.min(surr1, surr2).mean()
                # Value function loss
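                # Optionally clip the new value prediction to stay within clip_param
                # of the old prediction, mirroring the policy-ratio clipping above.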
                if self.use_clipped_value_loss:
                    value_pred_clipped = value_preds_batch + (
                        values - value_preds_batch
                    ).clamp(-self.clip_param, self.clip_param)
                    value_losses = (values - return_batch).pow(2)
                    value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
                    value_loss = (
                        0.5 * torch.max(value_losses, value_losses_clipped).mean()
                    )
                else:
                    value_loss = 0.5 * F.mse_loss(return_batch, values)
                # ======================= Backward pass =======================
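                # Total loss: coefficient-weighted value loss plus policy loss,
                # minus an entropy bonus that discourages premature convergence.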
                self.optimizer.zero_grad()
                (
                    value_loss * self.value_loss_coef
                    + action_loss
                    - dist_entropy * self.entropy_coef
                ).backward()
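                # Clip gradients of the encoder and actor-critic jointly before stepping.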
                nn.utils.clip_grad_norm_(
                    chain(self.encoder.parameters(), self.actor_critic.parameters()),
                    self.max_grad_norm,
                )
                self.optimizer.step()
                # ===================== Update statistics =====================
                value_loss_epoch += value_loss.item()
                action_loss_epoch += action_loss.item()
                dist_entropy_epoch += dist_entropy.item()

        num_updates = self.ppo_epoch * self.num_mini_batch
        value_loss_epoch /= num_updates
        action_loss_epoch /= num_updates
        dist_entropy_epoch /= num_updates

        losses = {}
        losses["value_loss"] = value_loss_epoch
        losses["action_loss"] = action_loss_epoch
        losses["dist_entropy"] = dist_entropy_epoch
        return losses
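
For reference, a minimal self-contained sketch (not part of the repository) of the
clipped surrogate term computed above, on toy tensors. The value clip_param = 0.2 is
an arbitrary example; in update() it comes from self.clip_param.

    import torch

    clip_param = 0.2
    old_log_prob = torch.tensor([-1.0])   # log pi_old(a|s) stored in the rollout
    new_log_prob = torch.tensor([-0.5])   # log pi_new(a|s) after a policy change
    adv = torch.tensor([2.0])             # positive advantage for this action

    # Importance ratio and the two surrogate terms, as in the update above.
    ratio = torch.exp(new_log_prob - old_log_prob)                        # ~1.649
    surr1 = ratio * adv                                                   # ~3.297
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv  # 1.2 * 2.0 = 2.4
    action_loss = -torch.min(surr1, surr2).mean()                         # -2.4: the ratio is capped
    print(ratio.item(), action_loss.item())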