# salina_examples/rl/ppo_continuous/ppo.py
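# This excerpt assumes the imports and helpers defined at the top of the file
# (not shown here), roughly:
#   import copy
#   import torch
#   import salina.rl.functional as RLF
#   from salina import Workspace, get_arguments, get_class
#   from salina.agents import Agents, NRemoteAgent, TemporalAgent
#   from salina.agents.gyma import AutoResetGymAgent
# plus a small _state_dict(agent, device) helper that moves an agent's state dict
# to the given device.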
def run_ppo(ppo_action_agent, ppo_critic_agent, logger, cfg):
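    # Name the policy so its copies inside the remote acquisition workers can be
    # retrieved by get_by_name(), then build an auto-resetting Gym environment agent
    # (each acquisition process owns n_envs / n_processes environment instances).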
    ppo_action_agent.set_name("ppo_action")
    env_agent = AutoResetGymAgent(
        get_class(cfg.algorithm.env),
        get_arguments(cfg.algorithm.env),
        n_envs=int(cfg.algorithm.n_envs / cfg.algorithm.n_processes),
    )
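    # Acquisition pipeline: a CPU copy of the policy is chained with the environment
    # agent, wrapped as a TemporalAgent, and distributed over n_processes worker
    # processes by NRemoteAgent.create, which also returns the shared workspace
    # the workers write their trajectories into.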
    acq_ppo_action = copy.deepcopy(ppo_action_agent)
    acq_agent = Agents(env_agent, acq_ppo_action)
    acq_agent = TemporalAgent(acq_agent)
    acq_remote_agent, acq_workspace = NRemoteAgent.create(
        acq_agent,
        num_processes=cfg.algorithm.n_processes,
        t=0,
        n_steps=cfg.algorithm.n_timesteps,
        stochastic=True,
        action_variance=0.0,
        replay=False,
    )
    acq_remote_agent.seed(cfg.algorithm.env_seed)
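    # Learner-side versions of the policy and critic, wrapped as TemporalAgents and
    # moved to the training device.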
    tppo_action_agent = TemporalAgent(ppo_action_agent).to(cfg.algorithm.device)
    tppo_critic_agent = TemporalAgent(ppo_critic_agent).to(cfg.algorithm.device)
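    # One optimizer per agent, instantiated from the config (e.g. torch.optim.Adam).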
    optimizer_args = get_arguments(cfg.algorithm.optimizer)
    parameters = ppo_action_agent.parameters()
    optimizer_action = get_class(cfg.algorithm.optimizer)(parameters, **optimizer_args)
    parameters = ppo_critic_agent.parameters()
    optimizer_critic = get_class(cfg.algorithm.optimizer)(parameters, **optimizer_args)
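    # Main loop: each epoch acquires a batch of n_timesteps transitions, then runs
    # several PPO policy epochs and several critic epochs on that batch.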
    epoch = 0
    iteration = 0
    for epoch in range(cfg.algorithm.max_epochs):
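        # Broadcast the latest policy weights (moved to CPU) to every remote
        # acquisition worker.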
        for a in acq_remote_agent.get_by_name("ppo_action"):
            a.load_state_dict(_state_dict(ppo_action_agent, "cpu"))
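        # Acquisition: after the first epoch, keep the last time step of the previous
        # workspace and collect only n_timesteps - 1 new steps, so trajectories stay
        # contiguous across epochs despite the auto-resetting environments.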
        if epoch > 0:
            acq_workspace.copy_n_last_steps(1)
            acq_remote_agent(
                acq_workspace,
                t=1,
                replay=False,
                n_steps=cfg.algorithm.n_timesteps - 1,
                action_variance=cfg.algorithm.action_variance,
            )
        else:
            acq_remote_agent(
                acq_workspace,
                t=0,
                replay=False,
                n_steps=cfg.algorithm.n_timesteps,
                action_variance=cfg.algorithm.action_variance,
            )
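        # Copy the acquired trajectories to the training device and evaluate the
        # critic on them without tracking gradients; these values only serve to
        # build the advantage targets for the policy update.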
        replay_workspace = Workspace(acq_workspace).to(cfg.algorithm.device)
        with torch.no_grad():
            tppo_critic_agent(replay_workspace, t=0, n_steps=cfg.algorithm.n_timesteps)
        critic, done, reward, action = replay_workspace[
            "critic", "env/done", "env/reward", "action"
        ]
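        # Generalized Advantage Estimation over the batch, detached so the policy
        # loss does not backpropagate into the critic. The log-probabilities written
        # during acquisition act as the "old" policy in the PPO ratio.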
        gae = RLF.gae(
            critic, reward, done, cfg.algorithm.discount_factor, cfg.algorithm.gae
        ).detach()
        old_action_lp = replay_workspace["action_logprobs"]
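        # Policy update: replay the stored transitions through the current policy and
        # maximize the PPO clipped surrogate objective plus an entropy bonus.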
        for _ in range(cfg.algorithm.pi_epochs):
            replay_workspace.zero_grad()
            tppo_action_agent(
                replay_workspace,
                replay=True,
                action_variance=cfg.algorithm.action_variance,
                t=0,
                n_steps=cfg.algorithm.n_timesteps,
            )
            action_lp = replay_workspace["action_logprobs"]
            entropy = replay_workspace["entropy"]
            ratio = (action_lp - old_action_lp).exp()
            ratio = ratio[:-1]
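            # The GAE tensor covers one step less than the trajectory (it needs the
            # next-state value), so the ratio is truncated to match before clipping.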
            clip_adv = (
                torch.clamp(
                    ratio, 1 - cfg.algorithm.clip_ratio, 1 + cfg.algorithm.clip_ratio
                )
                * gae
            )
            loss_pi = -(torch.min(ratio * gae, clip_adv)).mean()
            loss = loss_pi - cfg.algorithm.entropy_coef * entropy.mean()
            optimizer_action.zero_grad()
            loss.backward()
            if cfg.algorithm.clip_grad > 0:
                n = torch.nn.utils.clip_grad_norm_(
                    tppo_action_agent.parameters(), cfg.algorithm.clip_grad
                )
                logger.add_scalar("monitor/grad_norm_action", n.item(), iteration)
            optimizer_action.step()
            logger.add_scalar("loss_pi", loss_pi.item(), iteration)
            logger.add_scalar("loss_entropy", entropy.mean().item(), iteration)
            iteration += 1
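        # Critic update: recompute the values with gradients enabled and minimize the
        # mean squared advantage, pushing the value function toward its lambda-return
        # targets.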
        for _ in range(cfg.algorithm.v_epochs):
            replay_workspace.zero_grad()
            tppo_critic_agent(replay_workspace, t=0, n_steps=cfg.algorithm.n_timesteps)
            critic = replay_workspace["critic"]
            gae = RLF.gae(
                critic, reward, done, cfg.algorithm.discount_factor, cfg.algorithm.gae
            )
            optimizer_critic.zero_grad()
            loss = (gae ** 2).mean() * cfg.algorithm.critic_coef
            logger.add_scalar("loss_v", loss.item(), iteration)
            loss.backward()
            if cfg.algorithm.clip_grad > 0:
                n = torch.nn.utils.clip_grad_norm_(
                    tppo_critic_agent.parameters(), cfg.algorithm.clip_grad
                )
                logger.add_scalar("monitor/grad_norm_critic", n.item(), iteration)
            optimizer_critic.step()
            iteration += 1
        # Log the cumulated reward of the episodes that terminated in this batch.
        creward = replay_workspace["env/cumulated_reward"]
        creward = creward[done]
        if creward.size()[0] > 0:
            logger.add_scalar("reward", creward.mean().item(), epoch)