def _plot_test_results()

in automl21/scs_neural/experimentation/launcher.py [0:0]


    def _plot_test_results(self, dataset_type='test', n_iter=10, tag='t', dir_tag=None):
        """Evaluate the neural SCS solver on the validation or test set, then log and plot.

        The selected problem set is solved with ``self.scs_neural.solve`` (no
        gradients, ``train=False``). When the instance count differs from the
        configured batch size, it is solved in consecutive chunks of
        ``batch_size`` instances. The per-solution ``'loss'`` entries are reduced
        by ``self._compute_loss``.

        Args:
            dataset_type: ``'validate'`` selects the validation problems and only
                updates ``self.val_loss_meter``. Any other value selects the test
                problems. Only ``'test'`` updates ``self.test_loss_meter``.
            n_iter: ``max_iters`` passed to the solver.
            tag: Tag forwarded to ``self._plot_solution_results``.
            dir_tag: Optional suffix appended to ``"test/"`` to form the plot
                directory tag.

        Returns:
            One of three values:

            * ``None`` for ``'validate'``.
            * ``[]`` when ``self._compute_loss`` reports ``-1``.
            * Otherwise, the ``(aggregate, confidence)`` pair returned by
              ``self._extract_aggregate_metrics``.

            For non-validation runs, a CSV row (and optionally TensorBoard
            scalars) with the current loss-meter averages is written before
            returning.
        """
        if dataset_type == 'validate':
            problems, multi_instance = self.scs_validate_problem, self.val_multi_instance
            batch_size, graph_batch_size = self.cfg.validate_batch_size, self.cfg.validate_graph_batch_size
        else:
            problems, multi_instance = self.scs_test_problem, self.test_multi_instance
            batch_size, graph_batch_size = self.cfg.test_batch_size, self.cfg.test_graph_batch_size

        num_instances = multi_instance.num_instances
        with torch.no_grad():
            if num_instances == batch_size:
                soln_neural, scs_neural_metrics = self.scs_neural.solve(
                    multi_instance, max_iters=n_iter, track_metrics=True, train=False)
            else:
                # Solve in chunks of ``batch_size``; the last chunk may be smaller.
                soln_neural, scs_neural_metrics = [], []
                for start in range(0, num_instances, batch_size):
                    stop = min(start + batch_size, num_instances)
                    curr_batch = self.scs_neural.select_instances(
                        multi_instance, list(range(start, stop)))
                    batch_soln, batch_metrics = self.scs_neural.solve(
                        curr_batch, max_iters=n_iter, track_metrics=True, train=False)
                    # extend() rather than ``acc = acc + new`` avoids re-copying
                    # the accumulated results on every batch.
                    soln_neural.extend(batch_soln)
                    scs_neural_metrics.extend(batch_metrics)

        losses = [soln['loss'] for soln in soln_neural]
        loss, index_nans = self._compute_loss(losses)

        if dataset_type == 'validate':
            if loss > 0:
                self.val_loss_meter.update(loss.item())
            return

        if dataset_type == 'test' and loss > 0:
            self.test_loss_meter.update(loss.item())

        self.writer.writerow({
            'iter': self.itr,
            'train_loss': "%.6e" % (self.loss_meter.avg),
            'val_loss': "%.6e" %(self.val_loss_meter.avg),
            'test_loss': "%.6e" %(self.test_loss_meter.avg)
        })
        self.logf.flush()
        if self.cfg.log_tensorboard:
            self.sw.add_scalars("Loss", {"train": self.loss_meter.avg,
                                         "validate": self.val_loss_meter.avg,
                                         "test": self.test_loss_meter.avg},
                                self.itr)
        if loss == -1:
            # NOTE(review): -1 appears to be a sentinel from _compute_loss. A
            # caller that unpacks two values from this method would fail on
            # this empty list; confirm how callers handle it.
            return []

        upd_dir_tag = "test/" + (dir_tag if dir_tag is not None else "")

        # Build the exclusion set once (O(1) membership instead of a list scan
        # per id). int() normalizes numpy/torch scalar indices so they compare
        # equal to the Python ints produced by range().
        nan_ids = {int(idx) for idx in index_nans}
        candidate_ids = [i for i in range(len(problems.instances)) if i not in nan_ids]

        # Sampling without replacement raises ValueError when fewer candidates
        # remain than requested, so cap the sample at what is available.
        sample_size = min(graph_batch_size, len(candidate_ids))
        sampled_ids = np.random.choice(
            candidate_ids,
            size=sample_size,
            replace=False
        )

        agg_scs_neural, conf_scs_neural = self._extract_aggregate_metrics(
            scs_neural_metrics, soln_type='neural', index_nans=index_nans
        )
        self._plot_solution_results(sampled_ids, scs_neural_metrics, tag=tag,
                                    dir_tag=upd_dir_tag, title_stub='Test')
        return agg_scs_neural, conf_scs_neural