def train()

in train/train.py


def train(flags):
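    # Training always runs the environment in "remote" mode, i.e. it talks to the
    # RL server whose address is configured further down in this function.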
    flags.mode = "train"
    flags.cc_env_mode = "remote"

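    # When CUDA is available, keep the learner and the inference model on separate GPUs.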
    if torch.cuda.is_available():
        flags.learner_device = "cuda:0"
        flags.inference_device = "cuda:1"

    # For GALA
    proc_manager = mp.Manager()
    barrier = None
    shared_gossip_buffer = None

    # In GALA mode, start multiple replicas of the torchbeast-pantheon setup.
    num_agents = 1
    if flags.num_gala_agents > 1:
        num_agents = flags.num_gala_agents
        logging.info("In GALA mode, will start {} agents".format(num_agents))
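        # Barrier shared by all GALA learner processes so they can synchronize with
        # each other.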
        barrier = proc_manager.Barrier(num_agents)

        # Shared gossip buffer, allocated on GPU-0 (or on CPU if CUDA is unavailable)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        shared_gossip_buffer, _references = make_gossip_buffer(
            flags, num_agents, proc_manager, device
        )

    base_logdir = flags.base_logdir
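    # Per-agent process handles and stop events, indexed by agent rank.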
    learner_proc = []
    pantheon_proc = []
    stop_event = []
    for rank in range(num_agents):
        flags.base_logdir = (
            os.path.join(base_logdir, "gala_{}".format(rank))
            if num_agents > 1
            else base_logdir
        )
        init_logdirs(flags)

        # Unix domain socket path for RL server address, one per GALA agent.
        address = "/tmp/rl_server_path_{}".format(rank)
        try:
            os.remove(address)
        except OSError:
            pass
        flags.server_address = "unix:{}".format(address)

        # Round-robin device assignment for GALA
        if num_agents > 1 and torch.cuda.is_available():
            flags.learner_device = "cuda:{}".format(rank % torch.cuda.device_count())
            flags.inference_device = "cuda:{}".format(rank % torch.cuda.device_count())

        logging.info(
            "Starting agent {}. Mode={}, logdir={}".format(
                rank, flags.mode, flags.logdir
            )
        )

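        # Each agent gets its own learner process and Pantheon environment process.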
        stop_event.append(mp.Event())
        learner_proc.append(
            mp.Process(
                target=learner.main,
                kwargs=dict(
                    flags=flags,
                    rank=rank,
                    barrier=barrier,
                    gossip_buffer=shared_gossip_buffer,
                    stop_event=stop_event[-1],
                ),
                daemon=False,
            )
        )
        pantheon_proc.append(
            mp.Process(target=pantheon_env.main, args=(flags,), daemon=False)
        )
        learner_proc[rank].start()
        pantheon_proc[rank].start()

    # The shutdown sequence of a clean run is as follows:
    #   1. Wait until `stop_event` is set by the learner (signaling the end of training).
    #   2. Kill the Pantheon process.
    #   3. Clear `stop_event` to notify the learner that it can exit (in particular,
    #      stop the RPC server).
    #   4. Wait until the learner process has exited.
    # The motivation for this somewhat convoluted logic is that if we did not do #2
    # before stopping the RPC server (in #3), the Pantheon process would crash when the
    # RPC server is stopped, triggering meaningless error messages in the logs.
    for rank in range(num_agents):
        stop_event[rank].wait()
        logging.info(
            f"Stop event #{rank} set, will kill corresponding env (pid="
            f"{pantheon_proc[rank].pid})"
        )
        utils.kill_proc_tree(pantheon_proc[rank].pid)
        stop_event[rank].clear()
        learner_proc[rank].join()

    logging.info("Done training.")
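
For context, here is a minimal sketch of how train() could be driven from a command-line
entry point. The parser below is an illustrative assumption: it only exposes two of the
flags referenced above, while the real flag parsing in this repository is more extensive
(learner, pantheon_env and init_logdirs all read additional attributes). Treat it as a
usage sketch, not the project's actual CLI.

if __name__ == "__main__":
    import argparse

    # Hypothetical, minimal flag parser; not the repository's real one.
    parser = argparse.ArgumentParser(description="Sketch of a train() entry point")
    parser.add_argument("--base_logdir", default="/tmp/rl_train_logs")
    parser.add_argument(
        "--num_gala_agents",
        type=int,
        default=1,
        help="Number of GALA agents; values > 1 enable gossip-based training",
    )
    sketch_flags = parser.parse_args()  # argparse.Namespace, mutated by train()
    train(sketch_flags)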