in rlmeta/agents/ppo/ppo_agent.py [0:0]
def train_step(self, batch: NestedTensor) -> Dict[str, float]:
    device = next(self.model.parameters()).device
    # Move the sampled batch onto the model's device before the update.
    batch = nested_utils.map_nested(lambda x: x.to(device), batch)
    self.optimizer.zero_grad()

    action = batch["action"]
    action_logpi = batch["logpi"]  # Log-probs recorded by the behavior policy at rollout time.
    adv = batch["gae"]             # GAE advantages.
    ret = batch["return"]          # Targets for the value head.

    logpi, v = self.model_forward(batch)
    if self.value_clip:
        # Clipped value loss: penalize value predictions that move more than
        # eps_clip away from the values recorded at rollout time.
        v_batch = batch["v"]
        v_clamp = v_batch + (v - v_batch).clamp(-self.eps_clip, self.eps_clip)
        vf1 = (ret - v).square()
        vf2 = (ret - v_clamp).square()
        value_loss = torch.max(vf1, vf2).mean() * 0.5
    else:
        value_loss = (ret - v).square().mean() * 0.5
    entropy = -(logpi.exp() * logpi).sum(dim=-1).mean()
    entropy_loss = -self.entropy_ratio * entropy  # Entropy bonus to discourage premature collapse.
    if self.advantage_normalization:
        # Normalize advantages to zero mean and unit variance within the batch.
        std, mean = torch.std_mean(adv, unbiased=False)
        adv = (adv - mean) / std
    # Clipped surrogate (policy) objective.
    logpi = logpi.gather(dim=-1, index=action)
    ratio = (logpi - action_logpi).exp()  # pi_new(a|s) / pi_old(a|s)
    ratio_clamp = ratio.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip)
    surr1 = ratio * adv
    surr2 = ratio_clamp * adv
    policy_loss = -torch.min(surr1, surr2).mean()
    loss = policy_loss + value_loss + entropy_loss
    loss.backward()
    grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
    self.optimizer.step()
    return {
        "return": ret.detach().mean().item(),
        "entropy": entropy.detach().mean().item(),
        "policy_ratio": ratio.detach().mean().item(),
        "policy_loss": policy_loss.detach().mean().item(),
        "value_loss": value_loss.detach().mean().item(),
        "entropy_loss": entropy_loss.detach().mean().item(),
        "loss": loss.detach().mean().item(),
        "grad_norm": grad_norm.detach().mean().item(),
    }
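
Below is a minimal standalone sketch, not part of ppo_agent.py, that reproduces the clipped surrogate policy loss and the unclipped value loss from train_step on toy tensors. The helper name clipped_ppo_losses, the tensor shapes, and the eps_clip default of 0.2 are assumptions chosen for illustration only.

import torch


def clipped_ppo_losses(logpi, action, action_logpi, adv, ret, v, eps_clip=0.2):
    # Clipped surrogate policy loss, mirroring the policy-clip block above.
    logpi_a = logpi.gather(dim=-1, index=action)
    ratio = (logpi_a - action_logpi).exp()
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    policy_loss = -torch.min(surr1, surr2).mean()
    # Unclipped value loss (the value_clip=False branch above).
    value_loss = 0.5 * (ret - v).square().mean()
    return policy_loss, value_loss


batch_size, num_actions = 4, 3
logpi = torch.randn(batch_size, num_actions).log_softmax(dim=-1)  # current policy log-probs
action = torch.randint(num_actions, (batch_size, 1))              # sampled actions
action_logpi = logpi.gather(-1, action).detach() - 0.1            # stale rollout log-probs
adv = torch.randn(batch_size, 1)                                  # GAE advantages
ret = torch.randn(batch_size, 1)                                  # value targets
v = torch.randn(batch_size, 1)                                    # current value predictions

policy_loss, value_loss = clipped_ppo_losses(logpi, action, action_logpi, adv, ret, v)
print(f"policy_loss={policy_loss.item():.4f}  value_loss={value_loss.item():.4f}")

Taking the minimum of the unclipped and clipped surrogates is what keeps a single update from pushing the probability ratio far outside [1 - eps_clip, 1 + eps_clip], which is the core of the PPO objective computed in train_step.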