in ss_baselines/savi/ppo/ppo.py
def update(self, rollouts):
    advantages = self.get_advantages(rollouts)

    value_loss_epoch = 0
    action_loss_epoch = 0
    dist_entropy_epoch = 0

    for e in range(self.ppo_epoch):
        data_generator = rollouts.recurrent_generator(
            advantages, self.num_mini_batch
        )

        for sample in data_generator:
            (
                obs_batch,
                recurrent_hidden_states_batch,
                actions_batch,
                prev_actions_batch,
                value_preds_batch,
                return_batch,
                masks_batch,
                old_action_log_probs_batch,
                adv_targ,
                external_memory,
                external_memory_masks,
            ) = sample

            # Reshape to do in a single forward pass for all steps
            (
                values,
                action_log_probs,
                dist_entropy,
                _,
                _,
            ) = self.actor_critic.evaluate_actions(
                obs_batch,
                recurrent_hidden_states_batch,
                prev_actions_batch,
                masks_batch,
                actions_batch,
                external_memory,
                external_memory_masks,
            )

            # Importance ratio between the current policy and the policy
            # that collected the rollout.
            ratio = torch.exp(
                action_log_probs - old_action_log_probs_batch
            )
            # PPO clipped surrogate objective: take the pessimistic
            # (minimum) of the unclipped and clipped policy terms.
            surr1 = ratio * adv_targ
            surr2 = (
                torch.clamp(
                    ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
                )
                * adv_targ
            )
            action_loss = -torch.min(surr1, surr2).mean()

            if self.use_clipped_value_loss:
                # Clip the value update around the old value predictions
                # and take the worse of the clipped/unclipped errors.
                value_pred_clipped = value_preds_batch + (
                    values - value_preds_batch
                ).clamp(-self.clip_param, self.clip_param)
                value_losses = (values - return_batch).pow(2)
                value_losses_clipped = (
                    value_pred_clipped - return_batch
                ).pow(2)
                value_loss = (
                    0.5
                    * torch.max(value_losses, value_losses_clipped).mean()
                )
            else:
                value_loss = 0.5 * (return_batch - values).pow(2).mean()

            self.optimizer.zero_grad()
            total_loss = (
                value_loss * self.value_loss_coef
                + action_loss
                - dist_entropy * self.entropy_coef
            )

            self.before_backward(total_loss)
            total_loss.backward()
            self.after_backward(total_loss)

            self.before_step()
            self.optimizer.step()
            self.after_step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()

    # Average accumulated losses over all gradient steps taken this update.
    num_updates = self.ppo_epoch * self.num_mini_batch

    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch
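
A minimal sketch of the advantage computation that update() relies on, assuming a habitat-lab-style rollout storage with precomputed returns and value_preds; the actual savi get_advantages may differ, e.g. in whether and how it normalizes:

def get_advantages(self, rollouts):
    # Advantage estimates: bootstrapped returns minus the critic's value
    # predictions for every step except the final bootstrap entry.
    advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
    # Optionally standardize advantages for more stable policy updates.
    # use_normalized_advantage and the 1e-5 epsilon are assumed names/values.
    if not getattr(self, "use_normalized_advantage", True):
        return advantages
    return (advantages - advantages.mean()) / (advantages.std() + 1e-5)

In habitat-lab-style PPO implementations, the before_backward/after_backward and before_step/after_step hooks called above are typically where gradient clipping and any distributed gradient synchronization happen; subclasses such as a decentralized distributed PPO variant override them while leaving this update loop unchanged.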