in train/train.py [0:0]
def train(flags):
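    """Launch training: one learner and one Pantheon env process per (GALA) agent.

    `flags` is the shared argparse-style namespace. The function forces train
    mode and the "remote" env mode, assigns devices, starts the per-agent
    processes, and coordinates a clean shutdown once training finishes.
    """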
    flags.mode = "train"
    flags.cc_env_mode = "remote"

    if torch.cuda.is_available():
        flags.learner_device = "cuda:0"
        flags.inference_device = "cuda:1"
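
    # With multiple GALA agents, these device defaults are overridden per agent
    # by the round-robin assignment inside the loop below.
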
    # For GALA
    proc_manager = mp.Manager()
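    # `barrier` and `shared_gossip_buffer` stay None unless GALA mode (more than
    # one agent) is enabled below.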
    barrier = None
    shared_gossip_buffer = None

    # In GALA mode, start multiple replicas of the torchbeast-pantheon setup.
    num_agents = 1
    if flags.num_gala_agents > 1:
        num_agents = flags.num_gala_agents
        logging.info("In GALA mode, will start {} agents".format(num_agents))
        barrier = proc_manager.Barrier(num_agents)

        # Shared-gossip-buffer on GPU-0
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        shared_gossip_buffer, _references = make_gossip_buffer(
            flags, num_agents, proc_manager, device
        )

    base_logdir = flags.base_logdir
    learner_proc = []
    pantheon_proc = []
    stop_event = []
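
    # One learner process and one Pantheon environment process are started per
    # agent, together with a per-agent stop event used to coordinate shutdown.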
    for rank in range(num_agents):
        flags.base_logdir = (
            os.path.join(base_logdir, "gala_{}".format(rank))
            if num_agents > 1
            else base_logdir
        )
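        # `init_logdirs` prepares this agent's log directories under the
        # (per-agent) base_logdir.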
        init_logdirs(flags)

        # Unix domain socket path for RL server address, one per GALA agent.
        address = "/tmp/rl_server_path_{}".format(rank)
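        # Remove any stale socket file left over from a previous run.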
        try:
            os.remove(address)
        except OSError:
            pass
        flags.server_address = "unix:{}".format(address)

        # Round-robin device assignment for GALA
        if num_agents > 1 and torch.cuda.is_available():
            flags.learner_device = "cuda:{}".format(rank % torch.cuda.device_count())
            flags.inference_device = "cuda:{}".format(rank % torch.cuda.device_count())

        logging.info(
            "Starting agent {}. Mode={}, logdir={}".format(
                rank, flags.mode, flags.logdir
            )
        )
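        # The learner sets its `stop_event` once training is finished; the
        # shutdown loop below waits on it before tearing everything down.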
        stop_event.append(mp.Event())
        learner_proc.append(
            mp.Process(
                target=learner.main,
                kwargs=dict(
                    flags=flags,
                    rank=rank,
                    barrier=barrier,
                    gossip_buffer=shared_gossip_buffer,
                    stop_event=stop_event[-1],
                ),
                daemon=False,
            )
        )
        pantheon_proc.append(
            mp.Process(target=pantheon_env.main, args=(flags,), daemon=False)
        )
        learner_proc[rank].start()
        pantheon_proc[rank].start()

    # The shutdown sequence of a clean run is as follows:
    #   1. Wait until `stop_event` is set by the learner (= end-of-training notification).
    #   2. Kill the Pantheon process.
    #   3. Clear `stop_event` to notify the learner that it can exit (in particular,
    #      stop the RPC server).
    #   4. Wait until the learner process has exited.
    # The motivation for this somewhat convoluted logic is that if we don't do #2 before
    # stopping the RPC server (in #3), the Pantheon process will crash when the RPC
    # server is stopped, triggering meaningless error messages in the logs.
    for rank in range(num_agents):
        stop_event[rank].wait()
        logging.info(
            f"Stop event #{rank} set, will kill corresponding env (pid="
            f"{pantheon_proc[rank].pid})"
        )
        utils.kill_proc_tree(pantheon_proc[rank].pid)
        stop_event[rank].clear()
        learner_proc[rank].join()

    logging.info("Done training.")