in cfvpy/selfplay.py [0:0]
def __init__(self, cfg):
    self.cfg = cfg
    self.device = cfg.device or "cuda"
    ckpt_path = "."
    # On a SLURM cluster, derive rank and topology from the environment; rank 0 is the master.
    if heyhi.is_on_slurm():
        self.rank = int(os.environ["SLURM_PROCID"])
        self.is_master = self.rank == 0
        n_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
    else:
        self.rank = 0
        self.is_master = True
        n_nodes = 1
    logging.info(
        "Setup: is_master=%s n_nodes=%s rank=%s ckpt_path=%s",
        self.is_master,
        n_nodes,
        self.rank,
        ckpt_path,
    )
    # Size of the action space: num_dice * num_faces * 2 bids plus one extra (challenge) action.
    self.num_actions = cfg.env.num_dice * cfg.env.num_faces * 2 + 1
    self.net = _build_model(self.device, self.cfg.env, self.cfg.model)
    if self.is_master:
        # Optionally warm-start the network from a checkpoint; "loose" loading (strict=False)
        # tolerates missing or unexpected keys in the saved state dict.
        if self.cfg.load_checkpoint:
            logging.info("Loading checkpoint: %s", self.cfg.load_checkpoint)
            self.net.load_state_dict(
                torch.load(self.cfg.load_checkpoint),
                strict=not self.cfg.load_checkpoint_loose,
            )
        if self.cfg.selfplay.data_parallel:
            logging.info("data parallel")
            assert self.cfg.selfplay.num_master_threads == 0
            self.net = torch.nn.DataParallel(self.net)
        else:
            logging.info("Single machine mode")
    # Stopwatch for timing the stages of the training loop.
    self.train_timer = cfvpy.utils.MultiStopWatchTimer()
    if cfg.seed:
        # Note: a falsy seed (0 or None) leaves PyTorch's default seeding untouched.
        logging.info("Setting pytorch random seed to %s", cfg.seed)
        torch.manual_seed(cfg.seed)
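
For context on the strict=not self.cfg.load_checkpoint_loose argument above: PyTorch's load_state_dict raises on mismatched keys when strict=True (the default) and merely reports them when strict=False. A minimal standalone sketch with a toy module (not the cfvpy model or its checkpoint format):

# Toy illustration of strict vs. loose state_dict loading; unrelated to _build_model.
import torch

net = torch.nn.Linear(4, 2)
partial_ckpt = {"weight": torch.zeros(2, 4)}   # "bias" deliberately missing

# strict=True (the default) would raise a RuntimeError because "bias" is absent;
# strict=False loads what it can and reports the rest.
result = net.load_state_dict(partial_ckpt, strict=False)
print(result.missing_keys)   # ['bias']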