in habitat_baselines/rl/ppo/ppo.py [0:0]
def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
    advantages = self.get_advantages(rollouts)

    value_loss_epoch = 0.0
    action_loss_epoch = 0.0
    dist_entropy_epoch = 0.0
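
    # Each PPO epoch re-iterates over the same rollout, split into
    # recurrent mini-batches by the rollout storage.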
    for _e in range(self.ppo_epoch):
        profiling_wrapper.range_push("PPO.update epoch")
        data_generator = rollouts.recurrent_generator(
            advantages, self.num_mini_batch
        )

        for batch in data_generator:
            (
                values,
                action_log_probs,
                dist_entropy,
                _,
            ) = self._evaluate_actions(
                batch["observations"],
                batch["recurrent_hidden_states"],
                batch["prev_actions"],
                batch["masks"],
                batch["actions"],
            )
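
            # Importance ratio between the current policy and the policy that
            # collected the rollout, clipped to [1 - clip_param, 1 + clip_param]
            # for the surrogate objective.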
            ratio = torch.exp(action_log_probs - batch["action_log_probs"])
            surr1 = ratio * batch["advantages"]
            surr2 = (
                torch.clamp(
                    ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
                )
                * batch["advantages"]
            )
            action_loss = -(torch.min(surr1, surr2).mean())
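
            # Value loss, optionally clipped around the old value predictions
            # in the same way the policy ratio is clipped.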
            if self.use_clipped_value_loss:
                value_pred_clipped = batch["value_preds"] + (
                    values - batch["value_preds"]
                ).clamp(-self.clip_param, self.clip_param)
                value_losses = (values - batch["returns"]).pow(2)
                value_losses_clipped = (
                    value_pred_clipped - batch["returns"]
                ).pow(2)
                value_loss = 0.5 * torch.max(
                    value_losses, value_losses_clipped
                )
            else:
                value_loss = 0.5 * (batch["returns"] - values).pow(2)

            value_loss = value_loss.mean()
            dist_entropy = dist_entropy.mean()
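
            # Combined objective: weighted value loss plus the clipped
            # surrogate policy loss, minus an entropy bonus.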
            self.optimizer.zero_grad()
            total_loss = (
                value_loss * self.value_loss_coef
                + action_loss
                - dist_entropy * self.entropy_coef
            )

            self.before_backward(total_loss)
            total_loss.backward()
            self.after_backward(total_loss)

            self.before_step()
            self.optimizer.step()
            self.after_step()
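
            # Accumulate per-mini-batch statistics for averaging below.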
            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()

        profiling_wrapper.range_pop()  # PPO.update epoch

    num_updates = self.ppo_epoch * self.num_mini_batch

    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
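
For orientation, here is a minimal standalone sketch of the same clipped PPO objectives computed on toy tensors with plain PyTorch. The tensor values and the clip_param / loss-coefficient numbers are illustrative assumptions, not habitat-lab defaults or its config.

import torch

# Illustrative hyperparameters (assumptions, not habitat-lab defaults).
clip_param = 0.2
value_loss_coef = 0.5
entropy_coef = 0.01

# Toy mini-batch: log-probs under the current and the data-collecting policy,
# advantages, value predictions/targets, and per-sample entropy.
action_log_probs = torch.tensor([-1.0, -0.5, -2.0])
old_action_log_probs = torch.tensor([-1.1, -0.4, -1.5])
advantages = torch.tensor([0.5, -0.3, 1.2])
values = torch.tensor([0.9, 0.2, -0.1])
old_value_preds = torch.tensor([1.0, 0.0, 0.0])
returns = torch.tensor([1.1, 0.1, 0.3])
dist_entropy = torch.tensor([1.4, 1.3, 1.5])

# Clipped surrogate (policy) loss.
ratio = torch.exp(action_log_probs - old_action_log_probs)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
action_loss = -torch.min(surr1, surr2).mean()

# Clipped value loss: take the larger of the clipped and unclipped errors.
value_pred_clipped = old_value_preds + (values - old_value_preds).clamp(
    -clip_param, clip_param
)
value_loss = (
    0.5
    * torch.max(
        (values - returns).pow(2), (value_pred_clipped - returns).pow(2)
    ).mean()
)

total_loss = (
    value_loss * value_loss_coef
    + action_loss
    - dist_entropy.mean() * entropy_coef
)
print(float(total_loss))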