in fairdiplomacy/selfplay/exploit.py [0:0]
def run_training_loop(self):
    logging.info("Beginning training loop")
    max_epochs = self.cfg.trainer.max_epochs or 10 ** 9
    for self.state.epoch_id in range(self.state.epoch_id, max_epochs):
        # Save the state each epoch in case we need to requeue.
        if self.ectx.is_training_master:
            self.state.save(REQUEUE_CKPT)
            if (
                self.cfg.trainer.save_checkpoint_every
                and self.state.epoch_id % self.cfg.trainer.save_checkpoint_every == 0
            ):
                self.state.net_state.save(CKPT_MAIN_DIR / (CKPT_TPL % self.state.epoch_id))
                if self.state.value_net_state is not None:
                    self.state.value_net_state.save(
                        CKPT_VALUE_DIR / (CKPT_TPL % self.state.epoch_id)
                    )
        # Counters accumulate different statistics over the epoch. The default
        # accumulation strategy is averaging.
        counters = collections.defaultdict(fairdiplomacy.selfplay.metrics.FractionCounter)
        use_grad_clip = (self.cfg.optimizer.grad_clip or 0) > 1e-10
        if use_grad_clip:
            counters["optim/grad_max"] = fairdiplomacy.selfplay.metrics.MaxCounter()
        if not self.research:
            counters["score/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
            for p in POWERS:
                counters[f"score_{p}/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
        # For LR, just record its value at the start of the epoch.
        counters["optim/lr"].update(next(iter(self.state.optimizer.param_groups))["lr"])
        if self.state.value_net_state is not None:
            counters["optim/lr_value"].update(
                next(iter(self.state.value_net_state.optimizer.param_groups))["lr"]
            )
        epoch_start_time = time.time()
        for _ in range(self.cfg.trainer.epoch_size):
            self.do_step(counters=counters, use_grad_clip=use_grad_clip)
            # Log densely for the first steps, then only when global_step + 1 is a power of two.
            if (
                self.state.global_step < 128
                or (self.state.global_step & (self.state.global_step + 1)) == 0
            ):
                logging.info(
                    "Metrics (global_step=%d): %s",
                    self.state.global_step,
                    {k: v.value() for k, v in sorted(counters.items())},
                )
            self.state.global_step += 1
        if self.research:
            # TODO(akhti): I don't like this if.
            for key, value in self.data_loader.get_buffer_stats(prefix="buffer/").items():
                counters[key].update(value)
        epoch_scalars = {k: v.value() for k, v in sorted(counters.items())}
        average_batch_size = epoch_scalars["size/batch"]
        epoch_scalars["speed/loop_bps"] = self.cfg.trainer.epoch_size / (
            time.time() - epoch_start_time + 1e-5
        )
        epoch_scalars["speed/loop_eps"] = epoch_scalars["speed/loop_bps"] * average_batch_size
        # Speed for to_cuda + forward + backward.
        torch_time = epoch_scalars["time/net"] + epoch_scalars["time/to_cuda"]
        epoch_scalars["speed/train_bps"] = 1.0 / torch_time
        epoch_scalars["speed/train_eps"] = average_batch_size / torch_time
        if torch.cuda.is_available():
            for i in range(pynvml.nvmlDeviceGetCount()):
                mem_info = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(i))
                epoch_scalars[f"gpu_mem_used/{i}"] = mem_info.used / 2 ** 30
                epoch_scalars[f"gpu_mem_free/{i}"] = mem_info.free / 2 ** 30
        mem_stats = psutil.virtual_memory()
        epoch_scalars["memory/used_gb"] = mem_stats.used / 2 ** 30
        epoch_scalars["memory/available_gb"] = mem_stats.available / 2 ** 30
        epoch_scalars["memory/free_gb"] = mem_stats.free / 2 ** 30
        eval_scores = self.data_loader.extract_eval_scores()
        if eval_scores is not None:
            for k, v in eval_scores.items():
                epoch_scalars[f"score_eval/{k}"] = v
        logging.info(
            "Finished epoch %d. Metrics:\n%s",
            self.state.epoch_id,
            format_metrics_for_print(epoch_scalars),
        )
        self.logger.log_metrics(epoch_scalars, self.state.epoch_id)
        if self.state.scheduler is not None:
            self.state.scheduler.step()
        if self.state.value_net_state is not None and self.state.value_net_state.scheduler:
            self.state.value_net_state.scheduler.step()
    logging.info("End of training")
    logging.info("Exiting main function")