def train_step()

in rlmeta/agents/ppo/ppo_agent.py
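Runs one PPO optimization step on a sampled batch: moves the batch to the model's device, computes the (optionally clipped) value loss, an entropy bonus, and the clipped policy surrogate, backpropagates the combined loss, clips gradients, steps the optimizer, and returns scalar training metrics.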


    def train_step(self, batch: NestedTensor) -> Dict[str, float]:
        device = next(self.model.parameters()).device
        batch = nested_utils.map_nested(lambda x: x.to(device), batch)
        self.optimizer.zero_grad()

        action = batch["action"]
        action_logpi = batch["logpi"]
        adv = batch["gae"]
        ret = batch["return"]
        logpi, v = self.model_forward(batch)

        if self.value_clip:
            # Value clip
            v_batch = batch["v"]
            v_clamp = v_batch + (v - v_batch).clamp(-self.eps_clip,
                                                    self.eps_clip)
            vf1 = (ret - v).square()
            vf2 = (ret - v_clamp).square()
            value_loss = torch.max(vf1, vf2).mean() * 0.5
        else:
            value_loss = (ret - v).square().mean() * 0.5

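        # Entropy bonus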
        entropy = -(logpi.exp() * logpi).sum(dim=-1).mean()
        entropy_loss = -self.entropy_ratio * entropy

        if self.advantage_normalization:
            # Advantage normalization
            std, mean = torch.std_mean(adv, unbiased=False)
            adv = (adv - mean) / std

        # Policy clip
        logpi = logpi.gather(dim=-1, index=action)
        ratio = (logpi - action_logpi).exp()
        ratio_clamp = ratio.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip)
        surr1 = ratio * adv
        surr2 = ratio_clamp * adv
        policy_loss = -torch.min(surr1, surr2).mean()

        loss = policy_loss + value_loss + entropy_loss
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(),
                                             self.grad_clip)
        self.optimizer.step()

        return {
            "return": ret.detach().mean().item(),
            "entropy": entropy.detach().mean().item(),
            "policy_ratio": ratio.detach().mean().item(),
            "policy_loss": policy_loss.detach().mean().item(),
            "value_loss": value_loss.detach().mean().item(),
            "entropy_loss ": entropy_loss.detach().mean().item(),
            "loss": loss.detach().mean().item(),
            "grad_norm": grad_norm.detach().mean().item(),
        }
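
A minimal usage sketch, assuming a constructed PPOAgent and a replay source that yields NestedTensor batches; the `agent`, `replay`, `batch_size`, and `num_steps` names below are illustrative assumptions, not part of the code shown above.

    # Hedged sketch: `agent`, `replay`, `batch_size`, and `num_steps` are
    # assumed names for illustration only.
    for _ in range(num_steps):
        batch = replay.sample(batch_size)   # NestedTensor batch (moved to the model device inside train_step)
        metrics = agent.train_step(batch)
        print(metrics["policy_loss"], metrics["value_loss"], metrics["grad_norm"])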