Method: `_learn_batched`

From: automl21/scs_neural/experimentation/launcher.py


    def _learn_batched(self):
        """Run the batched training loop for the neural SCS accelerator.

        Sets up the model, Adam optimizer and (optionally) a cosine LR
        scheduler, scales and caches the train/validate/test instance sets
        once up front, then performs ``self.cfg.num_model_updates`` gradient
        steps.  Every ``test_freq`` iterations it evaluates and plots
        validation/test/train results; every ``save_freq`` iterations it
        checkpoints ``self`` to ``latest.pt`` and, on a new best validation
        loss, to ``best_model.pt``.

        Side effects: mutates ``self.lowest_val_loss``, ``self.sw``,
        ``self.opt``, ``self.scheduler``, ``self.multi_instance``,
        ``self.val_multi_instance``, ``self.test_multi_instance``,
        ``self.itr``, and writes checkpoint files to the working directory.

        Raises:
            RuntimeError: if no feasible training batch is found after
                10000 solve attempts.
        """
        # Best validation loss seen so far.  Initialized to +inf rather than
        # the old -1 sentinel: the sentinel silently misbehaved whenever
        # validation losses could be negative or exactly -1, since "unset"
        # was indistinguishable from a real loss value.
        self.lowest_val_loss = float("inf")
        if self.cfg.use_train_seed:
            torch.manual_seed(self.cfg.train_seed)
        # Dedicated RNG for batch sampling so data order is reproducible
        # independently of torch's global seed.
        rng_train_data = np.random.default_rng(self.cfg.train_data_seed)
        self.scs_neural.create_model(self.scs_problem)
        if self.cfg.log_tensorboard:
            self.sw = SummaryWriter(log_dir=self.cfg.tensorboard_dir)
        self.opt = torch.optim.Adam(
            self.scs_neural.accel.parameters(),
            lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2))
        if self.cfg.cosine_lr_decay:
            # Anneal over the full training horizon (num_model_updates steps).
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.opt, self.cfg.num_model_updates)
        # Scale and cache every instance once, up front; the val/test sets are
        # evaluation-only, so their caching runs without gradient tracking.
        self.multi_instance = self.scs_neural.scale_and_cache_all_instances(
            self.scs_problem.instances, use_scaling=self.cfg.scs.use_problem_scaling,
            scale=self.cfg.scs.scale, rho_x=self.cfg.scs.rho_x
        )
        with torch.no_grad():
            self.val_multi_instance = self.scs_neural.scale_and_cache_all_instances(
                self.scs_validate_problem.instances, use_scaling=self.cfg.scs.use_problem_scaling,
                scale=self.cfg.scs.scale, rho_x=self.cfg.scs.rho_x
            )
            self.test_multi_instance = self.scs_neural.scale_and_cache_all_instances(
                self.scs_test_problem.instances, use_scaling=self.cfg.scs.use_problem_scaling,
                scale=self.cfg.scs.scale, rho_x=self.cfg.scs.rho_x
            )
        self._reset_diffu_counts()
        # NOTE(review): self.itr is read before assignment here — presumably
        # initialized elsewhere (e.g. __init__); confirm it starts at 0.
        while self.itr < self.cfg.num_model_updates:
            # Draw a training batch of instance indices without replacement.
            sampled_ids = rng_train_data.choice(len(self.scs_problem.instances),
                                                size=self.cfg.train_batch_size,
                                                replace=False)
            num_tries = 0
            # Retry solving until a loss is available.  NOTE(review): the
            # batch is NOT resampled between tries, so this loop terminates
            # only if solve() is stochastic for fixed inputs — verify that,
            # or resample inside the loop.
            while True:
                num_tries += 1
                curr_multi_instance = self.scs_neural.select_instances(
                    self.multi_instance, sampled_ids)
                soln_neural, metrics, diffu_counts, loss_available = self.scs_neural.solve(
                        curr_multi_instance,
                        max_iters=self.cfg.num_iterations_train,
                        alpha=self.cfg.scs.alpha
                )
                if loss_available:
                    break
                if num_tries > 10000:
                    raise RuntimeError("Unable to find feasible train samples")
            losses = [soln_neural[i]['loss'] for i in range(self.cfg.train_batch_size)]
            # _compute_loss also reports which batch entries were NaN.
            loss, index_nans = self._compute_loss(losses)
            self.loss_meter.update(loss.item())
            self._update_diffu_counts(diffu_counts)
            self.opt.zero_grad()
            loss.backward()
            if self.cfg.clip_gradients:
                torch.nn.utils.clip_grad_norm_(
                    self.scs_neural.accel.parameters(),
                    self.cfg.max_gradient)
            # Log model internals only on clean (NaN-free) batches.
            if self.itr % self.cfg.test_freq == 0:
                if len(index_nans) == 0:
                    if self.cfg.log_tensorboard and hasattr(self.scs_neural.accel, 'log'):
                        self.scs_neural.accel.log(self.sw, self.itr)
            self.opt.step()
            if self.cfg.cosine_lr_decay:
                self.scheduler.step()
            if self.itr % self.cfg.test_freq == 0:
                print("Loss: ", loss.item())
                # Validation pass first — presumably this updates
                # self.val_loss_meter used for checkpointing below; confirm.
                self._plot_test_results(
                    n_iter=self.cfg.num_iterations_eval,
                    dataset_type='validate'
                )
                test_results = self._plot_test_results(
                    n_iter=self.cfg.num_iterations_eval, tag=f'{self.itr:06d}',
                    dir_tag=f'{self.itr // 1000:03d}'
                )
                train_results = self._plot_train_results(
                    n_iter=self.cfg.num_iterations_eval, tag=f'{self.itr:06d}',
                    dir_tag=f'{self.itr // 1000:03d}'
                )
                if len(test_results) > 0 and len(train_results) > 0:
                    self.plot_aggregate_results(
                        test_results, train_results,
                        tag=f'{self.itr:06d}',
                        dir_tag=f'{self.itr // 1000:03d}'
                    )
                self._reset_diffu_counts()
            if self.itr % self.cfg.save_freq == 0:
                torch.save(self, 'latest.pt')
                # Checkpoint on strict improvement; +inf init makes the old
                # "== -1 means unset" branch unnecessary.
                if self.val_loss_meter.avg < self.lowest_val_loss:
                    torch.save(self, 'best_model.pt')
                    self.lowest_val_loss = self.val_loss_meter.avg
            self.itr += 1