def run_trainer()

in cfvpy/selfplay.py [0:0]


    def run_trainer(self):
        # Fix version so that training always continues.
        if self.is_master:
            logger = pl_logging.TestTubeLogger(save_dir=os.getcwd(), version=0)

        # Storing the whole dict to preserve ref_models.
        datagen = self.initialize_datagen()
        context = datagen["context"]
        replay = datagen["replay"]
        policy_replay = datagen["policy_replay"]

        if self.cfg.data.train_preload:
            # Must preload data before starting generators to avoid deadlocks.
            _preload_data(self.cfg.data.train_preload, replay)
            preloaded_size = replay.size()
        else:
            preloaded_size = 0

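        # Build the value and policy optimizers, plus an optional LR scheduler for the value optimizer.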
        self.opt, self.policy_opt = self.configure_optimizers()
        self.scheduler = self.configure_scheduler(self.opt)

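        # Start the data-generation context (self-play workers).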
        context.start()

        if self.cfg.benchmark_data_gen:
            # Benchmark generation speed and exit.
            time.sleep(self.cfg.benchmark_data_gen)
            context.terminate()
            size = replay.num_add()
            logging.info(
                "BENCHMARK size %s speed %.2f", size, size / context.running_time
            )
            return

        train_size = self.cfg.data.train_epoch_size or 128 * 1000
        logging.info("Train set size (forced): %s", train_size)

        assert self.cfg.data.train_batch_size
        batch_size = self.cfg.data.train_batch_size
        epoch_size = train_size // batch_size

        if self.is_master:
            val_datasets = []

        logging.info(
            "model size is %s",
            sum(p.numel() for p in self.net.parameters() if p.requires_grad),
        )
        save_dir = pathlib.Path("ckpt")
        if self.is_master and not save_dir.exists():
            logging.info(f"Creating savedir: {save_dir}")
            save_dir.mkdir(parents=True)

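        # Wait until both replay buffers hold at least two batches of frames before training starts.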
        burn_in_frames = batch_size * 2
        while replay.size() < burn_in_frames or (
            policy_replay is not None and policy_replay.size() < burn_in_frames
        ):
            logging.info(
                "warming up replay buffer: %d/%d", replay.size(), burn_in_frames
            )
            if policy_replay is not None:
                logging.info(
                    "warming up POLICY replay buffer: %d/%d",
                    policy_replay.size(),
                    burn_in_frames,
                )
            time.sleep(30)

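        # Generation throughput helpers (batches added per second); the value-buffer rate excludes preloaded frames.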
        def compute_gen_bps():
            return (
                (replay.num_add() - preloaded_size) / context.running_time / batch_size
            )

        def compute_gen_bps_policy():
            return policy_replay.num_add() / context.running_time / batch_size

        metrics = None
        num_decays = 0
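        # Main training loop: each epoch runs epoch_size optimizer steps on freshly sampled batches.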
        for epoch in range(self.cfg.max_epochs):
            self.train_timer.start("start")
            if (
                epoch % self.cfg.decrease_lr_every == self.cfg.decrease_lr_every - 1
                and self.scheduler is None
            ):
                if (
                    not self.cfg.decrease_lr_times
                    or num_decays < self.cfg.decrease_lr_times
                ):
                    for param_group in self.opt.param_groups:
                        param_group["lr"] /= 2
                    num_decays += 1
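            # Periodically freeze a validation snapshot sampled from the replay buffer (master only).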
            if (
                self.cfg.create_validation_set_every
                and self.is_master
                and epoch % self.cfg.create_validation_set_every == 0
            ):
                logging.info("Adding new validation set")
                val_batches = [
                    replay.sample(batch_size, "cpu")[0]
                    for _ in range(512 * 100 // batch_size)
                ]
                val_datasets.append((f"valid_snapshot_{epoch:04d}", val_batches))

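            # Optionally dump the current replay buffer to disk as a supervised dataset.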
            if (
                self.cfg.selfplay.dump_dataset_every_epochs
                and epoch % self.cfg.selfplay.dump_dataset_every_epochs == 0
                and (not self.cfg.data.train_preload or epoch > 0)
            ):
                dataset_folder = pathlib.Path("dumped_data").resolve()
                dataset_folder.mkdir(exist_ok=True, parents=True)
                dataset_path = dataset_folder / f"data_{epoch:03d}.dat"
                logging.info(
                    "Saving replay buffer as supervised dataset to %s", dataset_path
                )
                replay.save(str(dataset_path))

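            # Reset per-epoch metrics and counters (fraction counters, plus max counters for gradient norms).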
            metrics = {}
            metrics["optim/lr"] = next(iter(self.opt.param_groups))["lr"]
            metrics["epoch"] = epoch
            counters = collections.defaultdict(cfvpy.utils.FractionCounter)
            if self.cfg.grad_clip:
                counters["optim/grad_max"] = cfvpy.utils.MaxCounter()
                if self.cfg.train_policy:
                    counters["optim_policy/grad_max"] = cfvpy.utils.MaxCounter()
            use_progress_bar = not heyhi.is_on_slurm() or self.cfg.show_progress_bar
            train_loader = range(epoch_size)
            train_device = self.device
            train_iter = tqdm.tqdm(train_loader) if use_progress_bar else train_loader
            training_start = time.time()

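            # Throttle training until data generation keeps up with the configured train/gen ratio.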
            if self.cfg.train_gen_ratio:
                while True:
                    if replay.num_add() * self.cfg.train_gen_ratio >= train_size * (
                        epoch + 1
                    ):
                        break
                    logging.info(
                        "Throttling to satisfy |replay| * ratio >= train_size * epochs:"
                        " %s * %s >= %s %s",
                        replay.num_add(),
                        self.cfg.train_gen_ratio,
                        train_size,
                        epoch + 1,
                    )
                    time.sleep(60)
            assert self.cfg.replay.use_priority is False, "Not supported"

            value_loss = policy_loss = 0  # For progress bar.
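            # Inner loop: alternate between value and policy batches (when a policy buffer exists),
            # taking one optimizer step per batch.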
            for iter_id in train_iter:
                self.train_timer.start("train-get_batch")
                use_policy_net = iter_id % 2 and policy_replay is not None
                if use_policy_net:
                    batch, _ = policy_replay.sample(batch_size, train_device)
                    suffix = "_policy"
                else:
                    batch, _ = replay.sample(batch_size, train_device)
                    suffix = ""
                self.train_timer.start("train-forward")
                self.net.train()
                loss_dict = self._compute_loss_dict(
                    batch, train_device, use_policy_net, timer_prefix="train-"
                )
                self.train_timer.start("train-backward")
                loss = loss_dict["loss"]
                opt = self.policy_opt if use_policy_net else self.opt
                params = (
                    self.get_policy_params()
                    if use_policy_net
                    else self.get_value_params()
                )
                opt.zero_grad()
                loss.backward()

                if self.cfg.grad_clip:
                    g_norm = clip_grad_norm_(params, self.cfg.grad_clip)
                else:
                    g_norm = None
                opt.step()
                loss.item()  # Force sync.
                self.train_timer.start("train-rest")
                if g_norm is not None:
                    g_norm = g_norm.item()
                    counters[f"optim{suffix}/grad_max"].update(g_norm)
                    counters[f"optim{suffix}/grad_mean"].update(g_norm)
                    counters[f"optim{suffix}/grad_clip_ratio"].update(
                        int(g_norm >= self.cfg.grad_clip - 1e-5)
                    )
                counters[f"loss{suffix}/train"].update(loss)
                for num_cards, partial_data in loss_dict["partials"].items():
                    counters[f"loss{suffix}/train_{num_cards}"].update(
                        partial_data["loss_sum"], partial_data["count"],
                    )
                    counters[f"val{suffix}/train_{num_cards}"].update(
                        partial_data["val_sum"], partial_data["count"],
                    )
                    counters[f"shares{suffix}/train_{num_cards}"].update(
                        partial_data["count"], batch_size
                    )

                if use_progress_bar:
                    if use_policy_net:
                        policy_loss = loss.detach().item()
                    else:
                        value_loss = loss.detach().item()
                    pbar_fields = dict(
                        policy_loss=policy_loss,
                        value_loss=value_loss,
                        buffer_size=replay.size(),
                        gen_bps=compute_gen_bps(),
                    )
                    if policy_replay is not None:
                        pbar_fields["pol_buffer_size"] = policy_replay.size()
                    train_iter.set_postfix(**pbar_fields)
                if self.cfg.fake_training:
                    # Generation benchmarking mode in which training is
                    # skipped. The goal is to measure generation speed without
                    # sample() calls.
                    break
            if self.cfg.fake_training:
                # Fake training epoch takes a minute.
                time.sleep(60)

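            # End-of-epoch bookkeeping: training/generation throughput, loss counters, and buffer statistics.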
            if len(train_loader) > 0:
                metrics["bps/train"] = len(train_loader) / (
                    time.time() - training_start
                )
                metrics["bps/train_examples"] = metrics["bps/train"] * batch_size
            logging.info(
                "[Train] epoch %d complete, avg error is %f",
                epoch,
                counters["loss/train"].value(),
            )
            if self.scheduler is not None:
                self.scheduler.step()
            for name, counter in counters.items():
                metrics[name] = counter.value()
            metrics["buffer/size"] = replay.size()
            metrics["buffer/added"] = replay.num_add()
            metrics["bps/gen"] = compute_gen_bps()
            metrics["bps/gen_examples"] = metrics["bps/gen"] * batch_size
            if policy_replay is not None:
                metrics["buffer/policy_size"] = policy_replay.size()
                metrics["buffer/policy_added"] = policy_replay.num_add()
                metrics["bps/gen_policy"] = compute_gen_bps_policy()
                metrics["bps/gen_policy_examples"] = (
                    metrics["bps/gen_policy"] * batch_size
                )

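            # Periodically copy the trained network into the model lockers used by the data-generation workers.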
            if (epoch + 1) % self.cfg.selfplay.network_sync_epochs == 0 or epoch < 15:
                logging.info("Copying current network to the eval network")
                for model_locker in datagen["model_lockers"]:
                    model_locker.update_model(self.get_model())
            if self.cfg.purging_epochs and (epoch + 1) in self.cfg.purging_epochs:
                new_size = max(
                    burn_in_frames,
                    int((self.cfg.purging_share_keep or 0.0) * replay.size()),
                )
                logging.info(
                    "Going to purge everything but %d elements in the buffer", new_size,
                )
                replay.pop_until(new_size)

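            # Every 10 epochs on the master: evaluate on stored validation snapshots, write checkpoints,
            # and optionally measure exploitability.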
            if self.is_master and epoch % 10 == 0:
                with torch.no_grad():
                    for i, (name, val_loader) in enumerate(val_datasets):
                        self.train_timer.start("valid-acc-extra")
                        eval_errors = []
                        val_iter = (
                            tqdm.tqdm(val_loader, desc="Eval")
                            if use_progress_bar
                            else val_loader
                        )
                        for data in val_iter:
                            self.net.eval()
                            loss = self._compute_loss_dict(
                                data, train_device, use_policy_net=False
                            )["loss"]
                            eval_errors.append(loss.detach().item())
                        current_error = sum(eval_errors) / len(eval_errors)
                        logging.info(
                            "[Eval] epoch %d complete, data is %s, avg error is %f",
                            epoch,
                            name,
                            current_error,
                        )
                        metrics[f"loss/{name}"] = current_error

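                # Save both a raw state_dict checkpoint and a TorchScript export of the current model.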
                self.train_timer.start("valid-trace")
                ckpt_path = save_dir / f"epoch{epoch}.ckpt"
                torch.save(self.get_model().state_dict(), ckpt_path)
                bin_path = ckpt_path.with_suffix(".torchscript")
                torch.jit.save(torch.jit.script(self.get_model()), str(bin_path))

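                # Every 20 epochs, compute exploitability and traversal-MSE statistics with the exported net.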
                self.train_timer.start("valid-exploit")
                if self.cfg.exploit and epoch % 20 == 0:
                    bin_path = pathlib.Path("tmp.torchscript")
                    torch.jit.save(torch.jit.script(self.get_model()), str(bin_path))
                    (
                        exploitability,
                        mse_net_traverse,
                        mse_fp_traverse,
                    ) = cfvpy.rela.compute_stats_with_net(
                        create_mdp_config(self.cfg.env), str(bin_path)
                    )
                    logging.info(
                        "Exploitability to leaf (epoch=%d): %.2f", epoch, exploitability
                    )
                    metrics["exploitability_last"] = exploitability
                    metrics["eval_mse/net_reach"] = mse_net_traverse
                    metrics["eval_mse/fp_reach"] = mse_fp_traverse

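            # Timing breakdown across phases, both averaged per epoch and as a percentage of total time.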
            if len(train_loader) > 0:
                metrics["bps/loop"] = len(train_loader) / (time.time() - training_start)
            total = 1e-5
            for k, v in self.train_timer.timings.items():
                metrics[f"timing/{k}"] = v / (epoch + 1)
                total += v
            for k, v in self.train_timer.timings.items():
                metrics[f"timing_pct/{k}"] = v * 100 / total
            logging.info("Metrics: %s", metrics)
            if self.is_master:
                logger.log_metrics(metrics)
                logger.save()
        return metrics