def run_a2c()

in salina_examples/rl/a2c/complete_with_async_eval/main.py

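A2C training loop with asynchronous evaluation: acquisition and evaluation rollouts run in remote worker processes, while the learner replays the acquired transitions on the training device and updates the policy and critic. The helpers _state_dict (presumably copies an agent's weights to a device) and _index (presumably gathers the probability of each taken action), together with the salina classes used below, are assumed to be defined or imported earlier in main.py. A hedged sketch of the usual imports, whose exact module paths follow other salina examples and may differ:

# Presumed imports for this example (paths based on other salina examples; adjust if they differ):
import copy
import torch
from salina import Workspace, get_arguments, get_class, instantiate_class
from salina.agents import Agents, NRemoteAgent, TemporalAgent
from salina.agents.gyma import AutoResetGymAgent
import salina.rl.functional as RLF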

def run_a2c(a2c_agent, logger, cfg):
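    # Build the logger from the configuration (this rebinds the logger argument)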
    logger = instantiate_class(cfg.logger)
    a2c_agent.set_name("a2c_agent")

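    # Environments must divide evenly across the acquisition processes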
    assert cfg.algorithm.n_envs % cfg.algorithm.n_processes == 0

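    # Auto-resetting Gym environment agents; each worker process gets an equal share of the environments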
    acq_env_agent = AutoResetGymAgent(
        get_class(cfg.algorithm.env),
        get_arguments(cfg.algorithm.env),
        n_envs=cfg.algorithm.n_envs // cfg.algorithm.n_processes,
    )
    eval_env_agent = AutoResetGymAgent(
        get_class(cfg.algorithm.env),
        get_arguments(cfg.algorithm.env),
        n_envs=cfg.algorithm.evaluation.n_envs
        // cfg.algorithm.evaluation.n_processes,
    )

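    # Acquisition: a copy of the A2C agent, combined with the environment agent, runs in n_processes remote workers with stochastic action sampling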
    acq_a2c_agent = copy.deepcopy(a2c_agent)
    acq_agent = TemporalAgent(Agents(acq_env_agent, acq_a2c_agent))
    acq_remote_agent, acq_workspace = NRemoteAgent.create(
        acq_agent,
        num_processes=cfg.algorithm.n_processes,
        t=0,
        n_steps=cfg.algorithm.n_timesteps,
        stochastic=True,
        replay=False,
    )
    acq_remote_agent.seed(cfg.algorithm.env_seed)

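    # Evaluation: a separate copy of the agent runs with stochastic=False in its own pool of remote workers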
    eval_a2c_agent = copy.deepcopy(a2c_agent)
    eval_agent = TemporalAgent(Agents(eval_env_agent, eval_a2c_agent))
    eval_remote_agent, eval_workspace = NRemoteAgent.create(
        eval_agent,
        num_processes=cfg.algorithm.evaluation.n_processes,
        t=0,
        n_steps=cfg.algorithm.evaluation.n_timesteps,
        stochastic=False,
        replay=False,
    )
    eval_remote_agent.seed(cfg.algorithm.evaluation.env_seed)

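    # Learner: the A2C agent is wrapped in a TemporalAgent and moved to the training device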
    ta2c_agent = TemporalAgent(a2c_agent).to(device=cfg.algorithm.device)

    optimizer_args = get_arguments(cfg.algorithm.optimizer)
    parameters = a2c_agent.parameters()
    optimizer = get_class(cfg.algorithm.optimizer)(parameters, **optimizer_args)

    epoch = 0

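    # Launch a first asynchronous evaluation with the initial weights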
    for a in eval_remote_agent.get_by_name("a2c_agent"):
        a.load_state_dict(_state_dict(a2c_agent, "cpu"))
    eval_remote_agent._asynchronous_call(
        eval_workspace,
        t=0,
        n_steps=cfg.algorithm.evaluation.n_timesteps,
        replay=False,
        stochastic=False,
    )

    for epoch in range(cfg.algorithm.max_epochs):
        # Evaluation: when the previous asynchronous rollout has finished, log its reward and relaunch it with the current weights
        if not eval_remote_agent.is_running():
            done, creward = eval_workspace["env/done", "env/cumulated_reward"]
            creward = creward[done]
            if creward.size()[0] > 0:
                logger.add_scalar("evaluation/reward", creward.mean().item(), epoch)
            for a in eval_remote_agent.get_by_name("a2c_agent"):
                a.load_state_dict(_state_dict(a2c_agent, "cpu"))
            eval_workspace.copy_n_last_steps(1)
            eval_remote_agent._asynchronous_call(
                eval_workspace,
                t=1,
                n_steps=cfg.algorithm.evaluation.n_timesteps - 1,
                replay=False,
                stochastic=False,
            )

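        # Push the latest learner weights to the acquisition workers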
        for a in acq_remote_agent.get_by_name("a2c_agent"):
            a.load_state_dict(_state_dict(a2c_agent, "cpu"))

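        # After the first epoch, keep the last step of the previous rollout and continue the trajectories from there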
        if epoch > 0:
            acq_workspace.copy_n_last_steps(1)
            acq_remote_agent(
                acq_workspace,
                t=1,
                replay=False,
                n_steps=cfg.algorithm.n_timesteps - 1,
                stochastic=True,
            )
        else:
            acq_remote_agent(
                acq_workspace,
                t=0,
                replay=False,
                n_steps=cfg.algorithm.n_timesteps,
                stochastic=True,
            )

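        # Replay the acquired transitions on the training device to recompute action probabilities and critic values with gradients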
        replay_workspace = Workspace(acq_workspace).to(cfg.algorithm.device)

        ta2c_agent(
            replay_workspace,
            t=0,
            n_steps=cfg.algorithm.n_timesteps,
            replay=True,
            stochastic=True,
        )

        critic, done, action_probs, reward, action = replay_workspace[
            "critic", "env/done", "action_probs", "env/reward", "action"
        ]

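        # Mask that zeroes the losses over the first burning_timesteps steps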
        burning = torch.zeros_like(reward)
        burning[cfg.algorithm.burning_timesteps :] = 1.0

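        # Generalized Advantage Estimation; td has one step less than the full trajectory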
        td = RLF.gae(
            critic, reward, done, cfg.algorithm.discount_factor, cfg.algorithm.gae
        )

        td_error = td ** 2
        critic_loss = (td_error * burning[:-1]).mean()

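        # Entropy of the action distribution, used as an exploration bonus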
        entropy_loss = (
            torch.distributions.Categorical(action_probs).entropy() * burning
        ).mean()

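        # Policy (actor) loss: log-probability of the taken actions weighted by the detached advantage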
        action_logp = _index(action_probs, action).log()
        a2c_loss = action_logp[:-1] * td.detach()
        a2c_loss = (a2c_loss * burning[:-1]).mean()

        logger.add_scalar("critic_loss", critic_loss.item(), epoch)
        logger.add_scalar("entropy_loss", entropy_loss.item(), epoch)
        logger.add_scalar("a2c_loss", a2c_loss.item(), epoch)

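        # Combined objective: maximize entropy and the actor term, minimize the critic error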
        loss = (
            -cfg.algorithm.entropy_coef * entropy_loss
            + cfg.algorithm.critic_coef * critic_loss
            - cfg.algorithm.a2c_coef * a2c_loss
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

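        # Log the cumulated reward of episodes that finished during acquisition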
        creward = replay_workspace["env/cumulated_reward"]
        creward = creward[done]
        if creward.size()[0] > 0:
            logger.add_scalar("reward", creward.mean().item(), epoch)