level_replay/algo/ppo.py
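# NOTE: excerpt of the update() method from ppo.py; it assumes the module's
# top-level imports (import torch; import torch.nn as nn), which the torch.*
# and nn.utils.* calls below rely on.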
def update(self, rollouts):
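    # Compute advantages for the rollout and normalize them to zero mean, unit std.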
    advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
    advantages = (advantages - advantages.mean()) / (
        advantages.std() + 1e-5)

    value_loss_epoch = 0
    action_loss_epoch = 0
    dist_entropy_epoch = 0

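    # Run several optimization epochs over the same rollout data.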
    for e in range(self.ppo_epoch):
        if self.actor_critic.is_recurrent:
            data_generator = rollouts.recurrent_generator(
                advantages, self.num_mini_batch)
        else:
            data_generator = rollouts.feed_forward_generator(
                advantages, self.num_mini_batch)

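        # Iterate over minibatches drawn from the rollout storage.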
        for sample in data_generator:
            obs_batch, recurrent_hidden_states_batch, actions_batch, \
                value_preds_batch, return_batch, masks_batch, \
                old_action_log_probs_batch, adv_targ = sample

            # Re-evaluate the current policy on the sampled observations and actions.
            values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
                obs_batch, recurrent_hidden_states_batch, masks_batch,
                actions_batch)

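            # Probability ratio between the current and behavior policies,
            # used in the PPO clipped surrogate objective.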
            ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
            surr1 = ratio * adv_targ
            surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                1.0 + self.clip_param) * adv_targ
            action_loss = -torch.min(surr1, surr2).mean()

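            # Clipped value loss: pessimistic max of clipped and unclipped squared errors.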
            value_pred_clipped = value_preds_batch + \
                (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
            value_losses = (values - return_batch).pow(2)
            value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
            value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()

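            # Combine value, policy, and entropy terms, then take a clipped gradient step.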
            self.optimizer.zero_grad()
            loss = (value_loss * self.value_loss_coef + action_loss
                    - dist_entropy * self.entropy_coef)
            loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                                     self.max_grad_norm)
            self.optimizer.step()

            value_loss_epoch += value_loss.item()
            action_loss_epoch += action_loss.item()
            dist_entropy_epoch += dist_entropy.item()

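    # Average the logged statistics over all minibatch updates in this call.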
    num_updates = self.ppo_epoch * self.num_mini_batch

    value_loss_epoch /= num_updates
    action_loss_epoch /= num_updates
    dist_entropy_epoch /= num_updates

    return value_loss_epoch, action_loss_epoch, dist_entropy_epoch