def learner_loop()

in train/learner.py


def learner_loop(flags, rank=0, barrier=None, gossip_buffer=None, stop_event=None):
    if flags.num_actors < flags.batch_size:
        logging.warning("Batch size is larger than number of actors.")
    assert (
        flags.batch_size % flags.inference_batch_size == 0
    ), "For now, inference_batch_size must divide batch_size"
    assert (
        flags.num_actors >= flags.inference_batch_size
    ), "Inference batch size must be <= number of actors"
    if flags.logdir:
        log_file_path = os.path.join(flags.logdir, "logs.tsv")
        logging.info(
            "%s logs to %s",
            "Appending" if os.path.exists(log_file_path) else "Writing",
            log_file_path,
        )
    else:
        logging.warning("--logdir not set. Not writing logs to file.")
    if not flags.cc_env_use_state_summary:
        raise NotImplementedError(
            "Setting cc_env_use_state_summary=False is currently not supported. The fetching "
            "of throughput and delay statistics through the states assumes that we are provided "
            "summary statistics, and would need to be updated to work without."
        )

    unroll_queue = queue.Queue(maxsize=1)
    log_queue = queue.Queue()
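    # NB: `queue` here must be a queue implementation that supports `close()` and a
    # `Closed` exception (both are used below); the standard library `queue` module
    # provides neither.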

    # Inference model.
    model = make_train_model(flags)
    # Dummy (observation, reward, done)
    dummy_env_output = (
        np.zeros(flags.observation_length, dtype=np.float32),
        np.array(0, dtype=np.float32),
        np.array(True, dtype=bool),
    )
    dummy_env_output = nest.map(
        lambda a: torch.from_numpy(np.array(a)), dummy_env_output
    )
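    # The dummy forward pass below captures the structure and dtypes of the model's
    # outputs; together with `dummy_env_output`, they serve as templates for the
    # rollout buffers allocated further down.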

    with torch.no_grad():
        dummy_model_output, _ = model(
            last_actions=torch.zeros([1], dtype=torch.int64),
            env_outputs=nest.map(lambda t: t.unsqueeze(0), dummy_env_output),
            core_state=model.initial_state(1),
        )
        dummy_model_output = nest.map(lambda t: t.squeeze(0), dummy_model_output)

    model = model.to(device=flags.inference_device)

    # TODO: Decide if we really want that for simple tensors?
    actions = StructuredBuffer(torch.zeros([flags.num_actors], dtype=torch.int64))
    actor_run_ids = StructuredBuffer(torch.zeros([flags.num_actors], dtype=torch.int64))
    actor_infos = StructuredBuffer(
        dict(
            episode_step=torch.zeros([flags.num_actors], dtype=torch.int64),
            episode_return=torch.zeros([flags.num_actors]),
            cwnd_mean=torch.zeros([flags.num_actors]),
            delay_mean=torch.zeros([flags.num_actors]),
            throughput_mean=torch.zeros([flags.num_actors]),
            train_job_id=torch.zeros([flags.num_actors], dtype=torch.int64),
        )
    )

    # Agent states at the beginning of an unroll. Needs to be kept for learner.
    # A state is a tuple of two tensors of shape [num_actors, hidden_size + 1]:
    #   - Why two tensors? Because the LSTM cell's state contains both its output and its
    #     internal state.
    #   - Why +1? Because the reward is appended to the hidden layer being fed to the LSTM,
    #     and in `SimpleNet` the LSTM cell has the same output size as its input size.
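    #   For example, `model.initial_state(4)` returns a pair of tensors (the LSTM
    #   cell's output and its internal state), each of shape [4, hidden_size + 1].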
    initial_states = model.initial_state(batch_size=flags.num_actors)
    first_agent_states = StructuredBuffer(initial_states)

    # Current agent states.
    agent_states = StructuredBuffer(copy.deepcopy(initial_states))

    rollouts = Rollouts(
        dict(
            last_actions=torch.zeros((), dtype=torch.int64),
            env_outputs=dummy_env_output,
            actor_outputs=dummy_model_output,
        ),
        unroll_length=flags.unroll_length,
        num_actors=flags.num_actors,
    )

    server = torchbeast.Server(flags.server_address, max_parallel_calls=4)

    def inference(actor_ids, run_ids, env_outputs):
        """
        Compute actions for a subset of actors, based on their observations.

        :param actor_ids: 1D tensor with the indices of actors whose actions
            are requested.
        :param run_ids: 1D tensor of same length as `actor_ids`, with the
            associated run IDs (used to deal with preemption / crashing: a
            fresh env is expected to provide a new run ID to be sure we do
            not accidentally re-use outdated data). Note that this mechanism
            is not actually needed by the current RL congestion control env.
        :param env_outputs: Tuple of observation (N x obs_dim), reward (N) and
            done (N) tensors, with N the size of `actor_ids`.
        """
        torch.set_grad_enabled(False)
        previous_run_ids = actor_run_ids.get(actor_ids)
        reset_indices = previous_run_ids != run_ids
        actor_run_ids.set(actor_ids, run_ids)

        actors_needing_reset = actor_ids[reset_indices]

        # Update new/restarted actors.
        # NB: this never happens because `runId` is always set to 0 in
    # `CongestionControlRPCEnv::makeCallRequest()`. This is working as
        # intended (calling `reset()` is not needed here).
        if actors_needing_reset.numel():
            logging.info("Actor ids needing reset: %s", actors_needing_reset.tolist())

            actor_infos.clear(actors_needing_reset)
            rollouts.reset(actors_needing_reset)
            actions.clear(actors_needing_reset)

            initial_agent_states = model.initial_state(actors_needing_reset.numel())
            first_agent_states.set(actors_needing_reset, initial_agent_states)
            agent_states.set(actors_needing_reset, initial_agent_states)

        obs, reward, done = env_outputs

        # Update logging stats at end of episode.
        done_ids = actor_ids[done]
        if done_ids.numel():
            # Do not log stats for zero-length episodes (these typically only occur
            # at the very start, when the `done` flag is set before any episode has
            # actually been run).
            valid_done_ids = done_ids[actor_infos.get(done_ids)["episode_step"] > 0]
            log_queue.put((valid_done_ids, actor_infos.get(valid_done_ids)))
            # NB: `actor_infos.get()` returned a copy of the data, so it is ok
            # to clear it now.
            actor_infos.clear(done_ids)
            # Clear reward for agents that are done: it is meaningless as it is obtained with
            # the first observation, before the agent got a chance to take any action.
            reward[done] = 0.0
            # We only update the `train_job_id` field once (when an episode starts, which
            # is when `done` is True). This is because it remains the same throughout the
            # whole episode.
            extra_info = {"train_job_id": (obs[:, -1] * done).long()}
        else:
            extra_info = {}
        actor_infos.add(
            actor_ids,
            dict(
                episode_step=1,
                episode_return=reward,
                cwnd_mean=state.get_mean(obs, state.Field.CWND, dim=1)
                * flags.cc_env_norm_bytes
                / UDP_SEND_PACKET_LEN,
                delay_mean=state.get_mean(obs, state.Field.DELAY, dim=1)
                * flags.cc_env_norm_ms,
                throughput_mean=state.get_mean(obs, state.Field.THROUGHPUT, dim=1)
                * flags.cc_env_norm_bytes
                / 1024 ** 2,  # convert to Mbytes/s
                **extra_info,
            ),
        )

        last_actions = actions.get(actor_ids)
        prev_agent_states = agent_states.get(actor_ids)

        actor_outputs, new_agent_states = model(
            *nest.map(
                lambda t: t.to(flags.inference_device),
                (last_actions, env_outputs, prev_agent_states),
            )
        )
        actor_outputs, new_agent_states = nest.map(
            lambda t: t.cpu(), (actor_outputs, new_agent_states)
        )

        timestep = dict(
            last_actions=last_actions,
            env_outputs=env_outputs,
            actor_outputs=actor_outputs,
        )
        completed_ids, unrolls = rollouts.append(actor_ids, timestep)
        if completed_ids.numel():
            try:
                unroll_queue.put(
                    (completed_ids, unrolls, first_agent_states.get(completed_ids)),
                    timeout=5.0,
                )
            except queue.Closed:
                if server.running():
                    raise

            # Importantly, `first_agent_states` must contain the states *before* processing
            # the current batch of data, because in `rollouts` we always start from the last
            # item of the previous rollout (and we need the state before that one). This is
            # why this line must happen before the update to `agent_states` below.
            first_agent_states.set(completed_ids, agent_states.get(completed_ids))

        agent_states.set(actor_ids, new_agent_states)

        action = actor_outputs["action"]
        actions.set(actor_ids, action)
        return action

    server.bind("inference", inference, batch_size=flags.inference_batch_size)
    server.run()

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        try:
            learn(
                model,
                executor,
                unroll_queue,
                log_queue,
                flags,
                rank,
                barrier,
                gossip_buffer,
                stop_event=stop_event,
            )
        except KeyboardInterrupt:
            print("Stopping ...")
        finally:
            unroll_queue.close()
            server.stop()
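
The buffers above (`actions`, `actor_run_ids`, `actor_infos`, `first_agent_states`,
`agent_states`) are instances of `StructuredBuffer`, whose implementation lives elsewhere
in the repo. The sketch below is illustrative only and reproduces the semantics implied by
the calls in `learner_loop()`: `get()` returns a copy indexed by actor id, `set()` writes
back, `clear()` zeroes, and `add()` accumulates. It only covers the plain-tensor and dict
cases; the real class also handles arbitrary nests (e.g. the tuple of LSTM state tensors),
and the name `StructuredBufferSketch` is made up for this example.

import torch


class StructuredBufferSketch:
    """Illustrative stand-in for StructuredBuffer (tensor or dict-of-tensors only)."""

    def __init__(self, data):
        # `data` is a tensor, or a dict of tensors, whose first dimension indexes actors.
        self._data = data

    def get(self, ids):
        # Return a copy so callers may clear/overwrite the buffer afterwards.
        if isinstance(self._data, dict):
            return {k: v[ids].clone() for k, v in self._data.items()}
        return self._data[ids].clone()

    def set(self, ids, values):
        # Overwrite the entries of the selected actors.
        if isinstance(self._data, dict):
            for k in self._data:
                self._data[k][ids] = values[k]
        else:
            self._data[ids] = values

    def clear(self, ids):
        # Reset the entries of the selected actors to zero.
        if isinstance(self._data, dict):
            for v in self._data.values():
                v[ids] = 0
        else:
            self._data[ids] = 0

    def add(self, ids, values):
        # Accumulate into the entries of the selected actors (values may be scalars).
        if isinstance(self._data, dict):
            for k, v in values.items():
                self._data[k][ids] += v
        else:
            self._data[ids] += values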