def actor_learner()

in main.py [0:0]


def actor_learner(args, rank, barrier, device, gossip_buffer):
    """ Single Actor-Learner Process """

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

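    # Pin this learner process to its assigned GPU (when running with CUDA)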
    if args.cuda and torch.cuda.is_available():
        torch.cuda.set_device(device)
    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # (Hack) Import here to ensure OpenAI-gym envs only run on the CPUs
    # corresponding to the processes' affinity
    from gala import utils
    from gala.envs import make_vec_envs
    # Make envs
    envs = make_vec_envs(args.env_name, args.seed, args.num_procs_per_learner,
                         args.gamma, args.log_dir, device, False,
                         rank=rank)

    # Initialize actor_critic
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    # Initialize agent
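    # (GALA_A2C performs local A2C updates and mixes parameters with the other
    # learners through the shared gossip_buffer, indexed by this learner's rank)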
    agent = GALA_A2C(
        actor_critic,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        alpha=args.alpha,
        max_grad_norm=args.max_grad_norm,
        rank=rank,
        gossip_buffer=gossip_buffer
    )

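    # On-policy storage for num_steps_per_update transitions from each of the
    # num_procs_per_learner parallel environments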
    rollouts = RolloutStorage(args.num_steps_per_update,
                              args.num_procs_per_learner,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    # Synchronize agents before starting training
    barrier.wait()
    print('%s: barrier passed' % rank)

    # Start training
    start = time.time()
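    # args.num_env_steps is a global budget across all learners, so convert it
    # into the number of updates each individual learner performs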
    num_updates = int(args.num_env_steps) // (
        args.num_steps_per_update
        * args.num_procs_per_learner
        * args.num_learners)
    # Guard against a zero interval (and a modulo-by-zero below) when
    # args.save_interval is smaller than the steps gathered per update
    save_interval = max(1, int(args.save_interval) // (
        args.num_steps_per_update
        * args.num_procs_per_learner
        * args.num_learners))

    for j in range(num_updates):

        # Decrease learning rate linearly
        if args.use_linear_lr_decay:
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates, args.lr)
        # --/

        # Step through environment
        # --
        for step in range(args.num_steps_per_update):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)
        # --/

        # Update parameters
        # --
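        # Bootstrap returns with the critic's value estimate of the last observation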
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()
        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()
        # --/

        # Save every "save_interval" local environment steps (or last update)
        if (j % save_interval == 0
                or j == num_updates - 1) and args.save_dir != '':
            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(args.save_dir,
                            '%s.%.3d.pt' % (rank, j // save_interval)))
        # --/

        # Log every "log_interval" local environment steps
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            num_steps = (j + 1) * args.num_procs_per_learner \
                * args.num_steps_per_update
            end = time.time()
            print(('{}: Updates {}, num timesteps {}, FPS {} ' +
                   '\n {}: Last {} training episodes: ' +
                   'mean/median reward {:.1f}/{:.1f}, ' +
                   'min/max reward {:.1f}/{:.1f}, ' +
                   'entropy {:.2f}, value loss {:.2f}, action loss {:.2f}\n').
                  format(rank, j, num_steps,
                         int(num_steps / (end - start)), rank,
                         len(episode_rewards),
                         np.mean(episode_rewards),
                         np.median(episode_rewards),
                         np.min(episode_rewards),
                         np.max(episode_rewards),
                         dist_entropy, value_loss, action_loss))
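
For context, one such process is launched per learner. The sketch below shows the general launch pattern, not the repo's actual main(): make_gossip_buffer and get_args are hypothetical stand-ins for the repo's own gossip-buffer construction and argument parsing, and actor_learner is assumed importable from main.

import torch
import torch.multiprocessing as mp


def train(args):
    # Barrier shared by all learners; actor_learner waits on it before training
    barrier = mp.Barrier(args.num_learners)
    # Shared object through which GALA_A2C mixes parameters across learners
    # (make_gossip_buffer is a hypothetical stand-in for the real construction)
    gossip_buffer = make_gossip_buffer(args)

    processes = []
    for rank in range(args.num_learners):
        device = torch.device('cuda:%d' % (rank % torch.cuda.device_count())
                              if args.cuda else 'cpu')
        p = mp.Process(target=actor_learner,
                       args=(args, rank, barrier, device, gossip_buffer))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == '__main__':
    mp.set_start_method('spawn')  # needed when sharing CUDA objects across processes
    train(get_args())             # get_args() is a hypothetical argument parser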