in fairdiplomacy/selfplay/exploit.py [0:0]
def __call__(self):
# Call this immediately to hand the data workers off to their long-running loops;
# they never return from it. After it returns, only the master process and the
# training helper processes remain here.
self._init_data_loader__workers_never_return()
self._init_training_device_and_state()
self._init_metrics_logger()
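# Only the training master creates the checkpoint directories.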
if self.ectx.is_training_master:
CKPT_MAIN_DIR.mkdir(exist_ok=True, parents=True)
if self.state.value_net_state is not None:
CKPT_VALUE_DIR.mkdir(exist_ok=True, parents=True)
# The training device and model state should be correct now, so make sure the
# model has been sent to the data-generation workers if it is not there yet.
self.send_model_to_workers()
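# Default everything to train mode; the eval-mode variants below may override parts of this.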
self.state.model.train()
if self.state.value_net_state is not None:
self.state.value_net_state.model.train()
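# Optional variants that keep all or part of the model in eval mode while training
# (e.g. frozen batch-norm statistics, no dropout).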
if self.cfg.trainer.train_as_eval:
assert self.state.value_net_state is None, "Not supported"
self.state.model.eval()
# Put cuDNN RNNs back into train mode: cuDNN RNN backward can only be called in training mode.
self.state.model.apply(_lstm_to_train)
elif self.cfg.trainer.train_encoder_as_eval:
assert self.state.value_net_state is None, "Not supported"
self.state.model.encoder.eval()
elif self.cfg.trainer.train_decoder_as_eval:
assert self.state.value_net_state is None, "Not supported"
self.state.model.policy_decoder.eval()
self.state.model.policy_decoder.apply(_lstm_to_train)
elif self.cfg.trainer.train_as_eval_but_batchnorm:
assert self.state.value_net_state is None, "Not supported"
self.state.model.eval()
self.state.model.apply(_lstm_to_train)
self.state.model.apply(_bn_to_train)
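# Move the models to the training device and, for the research trainer, optionally
# wrap them for multi-GPU training.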
if not self.research:
assert self.cfg.num_train_gpus == 1, "Only one training GPU for policy gradients"
self.state.model.to(self.device)
else:
assert self.cfg.num_train_gpus in (1, 2, 4)
if self.cfg.num_train_gpus > 1:
assert torch.cuda.device_count() == 8, "Can only go multi-gpu on full machine"
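# One training process per GPU: wrap each model in DistributedDataParallel pinned
# to this process's DDP rank.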
if self.cfg.num_train_gpus > 1 and self.cfg.use_distributed_data_parallel:
self.state.net_state.model = torch.nn.parallel.DistributedDataParallel(
self.state.net_state.model,
device_ids=(self.ectx.training_ddp_rank,),
output_device=self.ectx.training_ddp_rank,
)
if self.state.value_net_state is not None:
self.state.value_net_state.model = torch.nn.parallel.DistributedDataParallel(
self.state.value_net_state.model,
device_ids=(self.ectx.training_ddp_rank,),
output_device=self.ectx.training_ddp_rank,
)
else:
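# Single-process multi-GPU: replicate the models with DataParallel across the
# first num_train_gpus devices.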
self.state.net_state.model = torch.nn.DataParallel(
self.state.net_state.model,
device_ids=tuple(range(self.cfg.num_train_gpus)),
)
if self.state.value_net_state is not None:
self.state.value_net_state.model = torch.nn.DataParallel(
self.state.value_net_state.model,
device_ids=tuple(range(self.cfg.num_train_gpus)),
)
else:
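# num_train_gpus == 1: no wrapping needed, just move the models to the training device.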
self.state.net_state.model.to(self.device)
if self.state.value_net_state is not None:
self.state.value_net_state.model.to(self.device)
if torch.cuda.is_available():
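# Initialize NVML so GPU stats (memory, utilization) can be queried later.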
pynvml.nvmlInit()
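# Benchmark-only mode: run the benchmark and skip the training loop entirely.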
if self.research and self.cfg.search_rollout.benchmark_only:
self.perform_benchmark()
return
# Install the requeue handler at the last possible moment so that we don't requeue a zombie job.
heyhi.maybe_init_requeue_handler(self.on_requeue)
self.run_training_loop()
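
The _lstm_to_train and _bn_to_train helpers used above are Module.apply hooks defined elsewhere in this file; a minimal sketch of what such hooks typically look like (an illustrative assumption, not necessarily the repo's exact implementation):

def _lstm_to_train(module):
    # cuDNN RNN backward can only be called in training mode, so flip RNN layers back to train().
    if isinstance(module, torch.nn.modules.rnn.RNNBase):
        module.train()

def _bn_to_train(module):
    # Keep batch-norm layers in train mode so their running statistics keep updating.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.train()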