in gala/gala_a2c.py [0:0]
def update(self, rollouts):
    """Perform one A2C optimisation step on a collected rollout, then gossip.

    The rollout tensors are flattened so the whole batch is scored by a
    single ``actor_critic.evaluate_actions`` call. The value, policy and
    entropy terms are combined into one loss. Gradients are clipped to a
    global norm of ``self.max_grad_norm`` and the optimizer is stepped. If
    ``self.gossip_buffer`` is not None, the updated model is written to it
    and then aggregated through it.

    Args:
        rollouts: Rollout storage exposing ``obs``, ``actions``, ``rewards``,
            ``returns``, ``masks`` and ``recurrent_hidden_states`` tensors.
            ``rewards`` must be 3-D; its first two sizes are used as
            (num_steps, num_processes).

    Returns:
        tuple[float, float, float]: ``(value_loss, action_loss,
        dist_entropy)`` converted to Python floats via ``.item()``.
    """
    model = self.actor_critic
    steps, procs, _ = rollouts.rewards.size()
    per_step = (steps, procs, 1)

    # Collapse the leading (step, process) axes into one batch axis. The
    # last entry of ``obs`` / ``masks`` along axis 0 is dropped ([:-1]);
    # assumes those buffers hold num_steps + 1 entries there -- TODO confirm.
    obs_batch = rollouts.obs[:-1].view(-1, *rollouts.obs.size()[2:])
    # Only the hidden state at index 0 is passed in; presumably the model
    # re-unrolls the recurrence using ``masks`` -- confirm in evaluate_actions.
    hidden_batch = rollouts.recurrent_hidden_states[0].view(
        -1, model.recurrent_hidden_state_size)
    mask_batch = rollouts.masks[:-1].view(-1, 1)
    action_batch = rollouts.actions.view(-1, rollouts.actions.size()[-1])

    values, log_probs, dist_entropy, _ = model.evaluate_actions(
        obs_batch, hidden_batch, mask_batch, action_batch)

    advantages = rollouts.returns[:-1] - values.view(*per_step)
    value_loss = advantages.pow(2).mean()
    # detach(): the policy-gradient term must not backpropagate into the
    # value estimate that serves as its baseline.
    policy_term = advantages.detach() * log_probs.view(*per_step)
    action_loss = policy_term.mean().neg()

    self.optimizer.zero_grad()
    objective = (value_loss * self.value_loss_coef + action_loss
                 - dist_entropy * self.entropy_coef)
    objective.backward()
    nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
    self.optimizer.step()

    # Local-Gossip: publish this worker's freshly stepped model, then let the
    # buffer aggregate (the mixing rule is defined by gossip_buffer itself).
    buffer = self.gossip_buffer
    if buffer is not None:
        buffer.write_message(self.rank, model)
        buffer.aggregate_message(self.rank, model)

    return tuple(t.item() for t in (value_loss, action_loss, dist_entropy))