in fairdiplomacy/selfplay/exploit.py [0:0]
def _init_training_device_and_state(self):
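    """Pick the training device, load the model(s) and optimizers, restore any
    checkpoints, and (optionally) initialize distributed data parallel training."""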
    if not torch.cuda.is_available():
        logging.warning("No CUDA found!")
        self.device = "cpu"
    else:
        if self.cfg.use_distributed_data_parallel:
            # For now, ddp rank equals gpu index
            self.device = f"cuda:{self.ectx.training_ddp_rank}"  # Training device.
        else:
            self.device = "cuda"  # Training device.
    if self.cfg.value_model_path is not None:
        net, net_args = load_diplomacy_model_model_and_args(
            self.cfg.model_path,
            map_location=self.device,
            eval=True,
            override_has_policy=True,
            override_has_value=False,
        )
        optim, lr_scheduler = build_optimizer(net, self.cfg.optimizer)
        value_net, value_net_args = load_diplomacy_model_model_and_args(
            self.cfg.value_model_path,
            map_location=self.device,
            eval=True,
            override_has_policy=False,
            override_has_value=True,
        )
        value_optim, value_lr_scheduler = build_optimizer(
            value_net, self.cfg.value_optimizer or self.cfg.optimizer
        )
        value_state = NetTrainingState(
            args=value_net_args,
            model=value_net,
            optimizer=value_optim,
            scheduler=value_lr_scheduler,
        )
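    # No separate value model: load a single net from model_path and skip the value
    # training state entirely.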
    else:
        net, net_args = load_diplomacy_model_model_and_args(
            self.cfg.model_path, map_location=self.device, eval=True,
        )
        optim, lr_scheduler = build_optimizer(net, self.cfg.optimizer)
        value_state = None
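    # Bundle the net(s), optimizers, and schedulers into the trainer state.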
    self.state = TrainerState(
        net_state=NetTrainingState(
            model=net, optimizer=optim, scheduler=lr_scheduler, args=net_args,
        ),
        value_net_state=value_state,
    )
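    # Optionally discard the pretrained weights: re-initialize every submodule of the
    # main net that exposes reset_parameters().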
    if self.cfg.reset_agent_weights:

        def _reset(module):
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

        net.apply(_reset)
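    # Resume the full trainer state from a requeue checkpoint if one exists (either the
    # default REQUEUE_CKPT path or an explicitly configured one).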
    if REQUEUE_CKPT.exists():
        logging.info("Found requeue checkpoint: %s", REQUEUE_CKPT.resolve())
        self.state = TrainerState.load(REQUEUE_CKPT, self.state, self.device)
    elif self.cfg.requeue_ckpt_path:
        logging.info("Using explicit requeue checkpoint: %s", self.cfg.requeue_ckpt_path)
        p = pathlib.Path(self.cfg.requeue_ckpt_path)
        assert p.exists(), p
        self.state = TrainerState.load(p, self.state, self.device)
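    # No requeue checkpoint: fall back to the lexicographically newest checkpoint in the
    # per-net VALUE and MAIN checkpoint directories, if any.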
    else:
        if self.state.value_net_state is not None:
            if CKPT_VALUE_DIR.exists() and list(CKPT_VALUE_DIR.iterdir()):
                last_ckpt = max(CKPT_VALUE_DIR.iterdir(), key=str)
                logging.info(
                    "Found existing VALUE checkpoint folder. Will load last one: %s", last_ckpt
                )
                self.state.value_net_state = NetTrainingState.load(
                    last_ckpt, self.state.value_net_state, self.device
                )
            else:
                logging.info("No VALUE checkpoint found")
        if self.state.net_state is not None:
            if CKPT_MAIN_DIR.exists() and list(CKPT_MAIN_DIR.iterdir()):
                last_ckpt = max(CKPT_MAIN_DIR.iterdir(), key=str)
                logging.info(
                    "Found existing MAIN checkpoint folder. Will load last one: %s", last_ckpt
                )
                self.state.net_state = NetTrainingState.load(
                    last_ckpt, self.state.net_state, self.device
                )
            else:
                logging.info("No MAIN checkpoint found")
    if self.cfg.use_distributed_data_parallel:
        torch_ddp_init_fname = get_torch_ddp_init_fname()
        logging.info(f"Using {torch_ddp_init_fname} to coordinate launch of ddp processes.")
        logging.info("Waiting for distributed data parallel helpers to sync up")
        torch.distributed.init_process_group(
            "nccl",
            init_method=f"file://{torch_ddp_init_fname}",
            rank=self.ectx.training_ddp_rank,
            world_size=self.ectx.ddp_world_size,
        )
        logging.info("Distributed data parallel helpers synced, proceeding with training")