def run_training_loop()

in fairdiplomacy/selfplay/exploit.py [0:0]


    def run_training_loop(self):
        logging.info("Beginning training loop")
        max_epochs = self.cfg.trainer.max_epochs or 10 ** 9
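        # range() starts at the current epoch_id, so a requeued job resumes where it left off.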
        for self.state.epoch_id in range(self.state.epoch_id, max_epochs):
            # Save the full training state each epoch in case the job needs to be requeued.
            if self.ectx.is_training_master:
                self.state.save(REQUEUE_CKPT)
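                # Periodically write full model checkpoints in addition to the requeue checkpoint.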
                if (
                    self.cfg.trainer.save_checkpoint_every
                    and self.state.epoch_id % self.cfg.trainer.save_checkpoint_every == 0
                ):
                    self.state.net_state.save(CKPT_MAIN_DIR / (CKPT_TPL % self.state.epoch_id))
                    if self.state.value_net_state is not None:
                        self.state.value_net_state.save(
                            CKPT_VALUE_DIR / (CKPT_TPL % self.state.epoch_id)
                        )

            # Counters accumulate different statistics over the epoch. The default
            # accumulation strategy is averaging.
            counters = collections.defaultdict(fairdiplomacy.selfplay.metrics.FractionCounter)
            use_grad_clip = (self.cfg.optimizer.grad_clip or 0) > 1e-10
            if use_grad_clip:
                counters["optim/grad_max"] = fairdiplomacy.selfplay.metrics.MaxCounter()
            if not self.research:
                counters["score/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
                for p in POWERS:
                    counters[f"score_{p}/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
            # For LR just record its value at the start of the epoch.
            counters["optim/lr"].update(next(iter(self.state.optimizer.param_groups))["lr"])
            if self.state.value_net_state is not None:
                counters["optim/lr_value"].update(
                    next(iter(self.state.value_net_state.optimizer.param_groups))["lr"]
                )
            epoch_start_time = time.time()
            for _ in range(self.cfg.trainer.epoch_size):
                self.do_step(counters=counters, use_grad_clip=use_grad_clip)
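                # Log metrics on every step for the first 128 steps, then only when
                # global_step + 1 is a power of two (log-spaced logging).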
                if (
                    self.state.global_step < 128
                    or (self.state.global_step & self.state.global_step + 1) == 0
                ):
                    logging.info(
                        "Metrics (global_step=%d): %s",
                        self.state.global_step,
                        {k: v.value() for k, v in sorted(counters.items())},
                    )
                self.state.global_step += 1
            if self.research:
                # TODO(akhti): I don't like this if.
                for key, value in self.data_loader.get_buffer_stats(prefix="buffer/").items():
                    counters[key].update(value)

            epoch_scalars = {k: v.value() for k, v in sorted(counters.items())}
            average_batch_size = epoch_scalars["size/batch"]
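            # bps = batches per second, eps = examples per second; the loop_* rates cover the full step loop.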
            epoch_scalars["speed/loop_bps"] = self.cfg.trainer.epoch_size / (
                time.time() - epoch_start_time + 1e-5
            )
            epoch_scalars["speed/loop_eps"] = epoch_scalars["speed/loop_bps"] * average_batch_size
            # Speed of the to_cuda + forward + backward portion only.
            torch_time = epoch_scalars["time/net"] + epoch_scalars["time/to_cuda"]
            epoch_scalars["speed/train_bps"] = 1.0 / torch_time
            epoch_scalars["speed/train_eps"] = average_batch_size / torch_time

            if torch.cuda.is_available():
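                # Per-device GPU memory used/free, reported in GiB.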
                for i in range(pynvml.nvmlDeviceGetCount()):
                    mem_info = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(i))
                    epoch_scalars[f"gpu_mem_used/{i}"] = mem_info.used / 2 ** 30
                    epoch_scalars[f"gpu_mem_free/{i}"] = mem_info.free / 2 ** 30

            mem_stats = psutil.virtual_memory()
            epoch_scalars["memory/used_gb"] = mem_stats.used / 2 ** 30
            epoch_scalars["memory/available_gb"] = mem_stats.available / 2 ** 30
            epoch_scalars["memory/free_gb"] = mem_stats.free / 2 ** 30

            eval_scores = self.data_loader.extract_eval_scores()
            if eval_scores is not None:
                for k, v in eval_scores.items():
                    epoch_scalars[f"score_eval/{k}"] = v

            logging.info(
                "Finished epoch %d. Metrics:\n%s",
                self.state.epoch_id,
                format_metrics_for_print(epoch_scalars),
            )
            self.logger.log_metrics(epoch_scalars, self.state.epoch_id)

            if self.state.scheduler is not None:
                self.state.scheduler.step()
            if self.state.value_net_state is not None and self.state.value_net_state.scheduler:
                self.state.value_net_state.scheduler.step()

        logging.info("End of training")
        logging.info("Exiting main funcion")