in salina_examples/rl/a2c/multi_cpus/main.py [0:0]
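# Imports assumed by this excerpt (a sketch; the actual import list lives at the
# top of main.py and is not reproduced here). ProbAgent, ActionAgent, CriticAgent
# and the _index helper are defined elsewhere in the same example file.
import copy

import torch
import torch.nn as nn

from salina import Workspace, get_arguments, get_class, instantiate_class
from salina.agents import Agents, NRemoteAgent, TemporalAgent
from salina.agents.gyma import AutoResetGymAgent
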
def run_a2c(cfg):
    # Build the logger
    logger = instantiate_class(cfg.logger)

    # Get info on the environment
    env = instantiate_class(cfg.algorithm.env)
    observation_size = env.observation_space.shape[0]
    n_actions = env.action_space.n
    del env

    # The environments must split evenly across the acquisition processes
    assert cfg.algorithm.n_envs % cfg.algorithm.n_processes == 0
    # Create the agents
    acq_env_agent = AutoResetGymAgent(
        get_class(cfg.algorithm.env),
        get_arguments(cfg.algorithm.env),
        n_envs=int(cfg.algorithm.n_envs / cfg.algorithm.n_processes),
    )
    prob_agent = ProbAgent(
        observation_size, cfg.algorithm.architecture.hidden_size, n_actions
    )
    acq_prob_agent = copy.deepcopy(prob_agent)
    acq_action_agent = ActionAgent()
    acq_agent = TemporalAgent(Agents(acq_env_agent, acq_prob_agent, acq_action_agent))
    acq_remote_agent, acq_workspace = NRemoteAgent.create(
        acq_agent,
        num_processes=cfg.algorithm.n_processes,
        t=0,
        n_steps=cfg.algorithm.n_timesteps,
        stochastic=True,
    )
    acq_remote_agent.seed(cfg.algorithm.env_seed)
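
    # At this point acquisition is distributed: NRemoteAgent.create has spawned
    # cfg.algorithm.n_processes worker processes, each running its own copy of
    # acq_agent over n_envs / n_processes environments and writing the collected
    # transitions into acq_workspace.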
    critic_agent = CriticAgent(
        observation_size, cfg.algorithm.architecture.hidden_size, n_actions
    )
    tprob_agent = TemporalAgent(prob_agent)
    tcritic_agent = TemporalAgent(critic_agent)

    # 7) Configure the optimizer over the a2c agent
    optimizer_args = get_arguments(cfg.algorithm.optimizer)
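    # nn.Sequential is used here only as a parameter container, so that a single
    # optimizer updates both the policy (prob_agent) and the critic parameters.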
    parameters = nn.Sequential(prob_agent, critic_agent).parameters()
    optimizer = get_class(cfg.algorithm.optimizer)(parameters, **optimizer_args)

    # 8) Training loop
    epoch = 0
    for epoch in range(cfg.algorithm.max_epochs):
        # Synchronize the acquisition policy with the current learner weights
        pagent = acq_remote_agent.get_by_name("prob_agent")
        for a in pagent:
            a.load_state_dict(prob_agent.state_dict())

        if epoch > 0:
            # Keep the last timestep of the previous rollout as the new step 0
            # and continue acquisition from there
            acq_workspace.copy_n_last_steps(1)
            acq_remote_agent(
                acq_workspace,
                t=1,
                n_steps=cfg.algorithm.n_timesteps - 1,
                stochastic=True,
            )
        else:
            acq_remote_agent(
                acq_workspace, t=0, n_steps=cfg.algorithm.n_timesteps, stochastic=True
            )
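
        # Replay: re-run the learnable policy and critic over the acquired
        # transitions so their outputs are part of the computation graph
        # (the tensors produced by the remote workers carry no gradients).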
        replay_workspace = Workspace(acq_workspace)
        tprob_agent(replay_workspace, t=0, n_steps=cfg.algorithm.n_timesteps)
        tcritic_agent(replay_workspace, t=0, n_steps=cfg.algorithm.n_timesteps)

        critic, done, action_probs, reward, action = replay_workspace[
            "critic", "env/done", "action_probs", "env/reward", "action"
        ]
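
        # Bootstrapped one-step target:
        #   target_t = r_{t+1} + gamma * V(s_{t+1}) * (1 - done_{t+1})
        # The TD error (target_t - V(s_t)) drives the squared critic loss and,
        # detached, serves as the advantage estimate in the policy-gradient term.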
        target = reward[1:] + cfg.algorithm.discount_factor * critic[1:].detach() * (
            1 - done[1:].float()
        )
        td = target - critic[:-1]
        td_error = td ** 2
        critic_loss = td_error.mean()

        entropy_loss = torch.distributions.Categorical(action_probs).entropy().mean()
        action_logp = _index(action_probs, action).log()
        a2c_loss = action_logp[:-1] * td.detach()
        a2c_loss = a2c_loss.mean()
logger.add_scalar("critic_loss", critic_loss.item(), epoch)
logger.add_scalar("entropy_loss", entropy_loss.item(), epoch)
logger.add_scalar("a2c_loss", a2c_loss.item(), epoch)
loss = (
-cfg.algorithm.entropy_coef * entropy_loss
+ cfg.algorithm.critic_coef * critic_loss
- cfg.algorithm.a2c_coef * a2c_loss
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
        # Log the average cumulated reward of the episodes that terminated
        # during this rollout (done is a boolean mask over the workspace)
        creward = replay_workspace["env/cumulated_reward"]
        creward = creward[done]
        if creward.size()[0] > 0:
            logger.add_scalar("reward", creward.mean().item(), epoch)