salina_examples/rl/ppo_brax/ppo.py [213:250]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                # Scale rewards and compute Generalized Advantage Estimation;
                # the advantage is detached so it acts as a constant in the policy loss.
                reward = reward * cfg.algorithm.reward_scaling
                gae = RLF.gae(
                    critic,
                    reward,
                    done,
                    cfg.algorithm.discount_factor,
                    cfg.algorithm.gae,
                ).detach()
                # Importance ratio between the current policy and the policy that
                # collected the data; drop the last step to align with the advantages.
                action_lp = miniworkspace["action_logprobs"]
                ratio = action_lp - old_action_lp
                ratio = ratio.exp()
                ratio = ratio[:-1]
                # PPO clipped surrogate objective.
                clip_adv = (
                    torch.clamp(
                        ratio,
                        1 - cfg.algorithm.clip_ratio,
                        1 + cfg.algorithm.clip_ratio,
                    )
                    * gae
                )
                loss_policy = -(torch.min(ratio * gae, clip_adv)).mean()

                # TD(0) error as the critic regression target.
                td0 = RLF.temporal_difference(
                    critic, reward, done, cfg.algorithm.discount_factor
                )
                loss_critic = (td0 ** 2).mean()
                # Joint backward pass, then separate optimizer steps.
                optimizer_critic.zero_grad()
                optimizer_policy.zero_grad()
                (loss_policy + loss_critic).backward()
                n = clip_grad(action_agent.parameters(), cfg.algorithm.clip_grad)
                optimizer_policy.step()
                optimizer_critic.step()
                # NOTE: n is the gradient norm of the policy (action_agent) parameters;
                # the same value is logged under the critic entry as well.
                logger.add_scalar("monitor/grad_norm_policy", n.item(), iteration)
                logger.add_scalar("loss/policy", loss_policy.item(), iteration)
                logger.add_scalar("loss/critic", loss_critic.item(), iteration)
                logger.add_scalar("monitor/grad_norm_critic", n.item(), iteration)
                iteration += 1
        epoch += 1
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
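
For reference, the clipped surrogate objective that both excerpts implement can be reproduced in isolation. The sketch below uses only torch and synthetic tensors; the shapes, the clip_ratio value, and the variable names are illustrative stand-ins for the config-driven values in the excerpts, not values from the repository.

import torch

T, B = 5, 4                           # time steps, parallel environments
clip_ratio = 0.2                      # stand-in for cfg.algorithm.clip_ratio

action_lp = torch.randn(T, B)         # log-probs under the current policy
old_action_lp = torch.randn(T, B)     # log-probs recorded when the data was collected
gae = torch.randn(T - 1, B)           # advantages, one step shorter than the log-probs

# Importance ratio pi_new / pi_old; drop the last time step to align with the advantages.
ratio = (action_lp - old_action_lp).exp()[:-1]

# Clipped surrogate: take the pessimistic (minimum) of the unclipped and clipped terms.
clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * gae
loss_policy = -(torch.min(ratio * gae, clip_adv)).mean()
print(loss_policy.item())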



salina_examples/rl/ppo_brax_transformer/ppo.py [195:232]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                # Scale rewards and compute Generalized Advantage Estimation;
                # the advantage is detached so it acts as a constant in the policy loss.
                reward = reward * cfg.algorithm.reward_scaling
                gae = RLF.gae(
                    critic,
                    reward,
                    done,
                    cfg.algorithm.discount_factor,
                    cfg.algorithm.gae,
                ).detach()
                # Importance ratio between the current policy and the policy that
                # collected the data; drop the last step to align with the advantages.
                action_lp = miniworkspace["action_logprobs"]
                ratio = action_lp - old_action_lp
                ratio = ratio.exp()
                ratio = ratio[:-1]
                # PPO clipped surrogate objective.
                clip_adv = (
                    torch.clamp(
                        ratio,
                        1 - cfg.algorithm.clip_ratio,
                        1 + cfg.algorithm.clip_ratio,
                    )
                    * gae
                )
                loss_policy = -(torch.min(ratio * gae, clip_adv)).mean()

                # TD(0) error as the critic regression target.
                td0 = RLF.temporal_difference(
                    critic, reward, done, cfg.algorithm.discount_factor
                )
                loss_critic = (td0 ** 2).mean()
                # Joint backward pass, then separate optimizer steps.
                optimizer_critic.zero_grad()
                optimizer_policy.zero_grad()
                (loss_policy + loss_critic).backward()
                n = clip_grad(action_agent.parameters(), cfg.algorithm.clip_grad)
                optimizer_policy.step()
                optimizer_critic.step()
                # NOTE: n is the gradient norm of the policy (action_agent) parameters;
                # the same value is logged under the critic entry as well.
                logger.add_scalar("monitor/grad_norm_policy", n.item(), iteration)
                logger.add_scalar("loss/policy", loss_policy.item(), iteration)
                logger.add_scalar("loss/critic", loss_critic.item(), iteration)
                logger.add_scalar("monitor/grad_norm_critic", n.item(), iteration)
                iteration += 1
        epoch += 1
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
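
Since the update step is identical in both files, it could be lifted into a shared helper that each script calls from its inner loop. The sketch below is only a refactoring suggestion, not code from the repository: the function name ppo_update and its argument list are hypothetical, and RLF and clip_grad are assumed to be the same helpers the two ppo.py scripts already use.

import torch

def ppo_update(cfg, critic, reward, done, action_lp, old_action_lp,
               action_agent, optimizer_policy, optimizer_critic,
               logger, iteration):
    """One PPO policy/critic update, as performed in both excerpts.

    Assumes RLF (salina's RL functional helpers) and clip_grad are available
    in the calling module, exactly as in the two ppo.py scripts.
    """
    reward = reward * cfg.algorithm.reward_scaling
    gae = RLF.gae(
        critic, reward, done, cfg.algorithm.discount_factor, cfg.algorithm.gae
    ).detach()

    # Clipped surrogate policy loss.
    ratio = (action_lp - old_action_lp).exp()[:-1]
    clip_adv = (
        torch.clamp(ratio, 1 - cfg.algorithm.clip_ratio, 1 + cfg.algorithm.clip_ratio)
        * gae
    )
    loss_policy = -(torch.min(ratio * gae, clip_adv)).mean()

    # TD(0) critic loss.
    td0 = RLF.temporal_difference(critic, reward, done, cfg.algorithm.discount_factor)
    loss_critic = (td0 ** 2).mean()

    # Joint backward pass, separate optimizer steps.
    optimizer_critic.zero_grad()
    optimizer_policy.zero_grad()
    (loss_policy + loss_critic).backward()
    n = clip_grad(action_agent.parameters(), cfg.algorithm.clip_grad)
    optimizer_policy.step()
    optimizer_critic.step()

    logger.add_scalar("loss/policy", loss_policy.item(), iteration)
    logger.add_scalar("loss/critic", loss_critic.item(), iteration)
    logger.add_scalar("monitor/grad_norm_policy", n.item(), iteration)
    # (The excerpts also log the same value under monitor/grad_norm_critic.)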



