def test()

in aiops/ContraAD/solver.py [0:0]


    def test(self):
        """Evaluate the trained model and log point-adjusted detection metrics.

        Steps:
          1. Load ``<model_save_path>/<data_path>_checkpoint_<win_size>.pth``.
          2. Score ``self.train_loader`` and ``self.test_loader``. The anomaly
             threshold is the ``100 - self.anormly_ratio`` percentile of their
             combined energies.
          3. Score ``self.thre_loader`` against that threshold and apply point
             adjustment.
          4. Print accuracy / precision / recall / F1 and append them as one
             JSON line to ``<self.dataset>.log``.

        Returns:
            ``(accuracy, precision, recall, f_score)`` on the local main
            process; ``None`` on every other process.
        """
        self.model = accelerator.unwrap_model(self.model)
        checkpoint_path = os.path.join(
            str(self.model_save_path), str(self.data_path) + f"_checkpoint_{self.win_size}.pth"
        )
        # Load onto CPU so every process does not materialise the tensors on
        # the device they were saved from; load_state_dict copies them into
        # the model's own parameters.
        self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        self.model.eval()

        # (1) statistics on the train set and (2) energies on the test set;
        # together they determine the threshold.
        train_energy, _ = self._gather_test_energy(self.train_loader)
        test_energy, _ = self._gather_test_energy(self.test_loader)
        # (3) the evaluation set, gathered together with its ground truth.
        thre_energy, thre_labels = self._gather_test_energy(self.thre_loader, with_labels=True)

        # Only one process per node computes and reports the metrics.
        if not accelerator.is_local_main_process:
            return None

        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        thresh = np.percentile(combined_energy, 100 - self.anormly_ratio)

        pred = (thre_energy > thresh).astype(int)
        gt = thre_labels.astype(int)
        print(len(gt), len(pred))

        pred = self._point_adjust_predictions(gt, pred)

        from sklearn.metrics import precision_recall_fscore_support
        from sklearn.metrics import accuracy_score

        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, _ = precision_recall_fscore_support(
            gt, pred, average="binary"
        )
        result_dict = {
            "anomaly_ratio": self.anormly_ratio,
            "win_size": self.win_size,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f_score": f_score,
            "thre": thresh,
        }
        print(result_dict)
        with open(f"{self.dataset}.log", mode="a") as f:
            f.write(json.dumps(result_dict))
            f.write("\n")

        return accuracy, precision, recall, f_score

    def _gather_test_energy(self, loader, with_labels=False):
        """Score every batch of ``loader`` and gather the scores from all processes.

        Each batch goes through ``self.model``, then
        ``cal_metric(..., model_mode='test')``, then a softmax over dim 1.
        The result is collected with ``accelerator.gather_for_metrics``, so
        each process ends up holding the scores for the whole loader.

        Args:
            loader: iterable yielding ``(input_data, labels)`` batches.
            with_labels: when True, also gather and return the labels.

        Returns:
            ``(energy, labels)`` as flattened 1-D numpy arrays. ``labels`` is
            ``None`` when ``with_labels`` is False. Either array is empty when
            ``loader`` yields no batches.
        """
        # Local buffers: results must never leak between calls or loaders.
        energy_chunks = []
        label_chunks = []
        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            for input_data, labels in loader:
                input_data = input_data.to(self.device)
                intra = self.model(input_data)
                out = cal_metric(
                    x=intra,
                    z_score=None,
                    mode=self.loss_mode,
                    soft=self.soft,
                    soft_mode=self.soft_mode,
                    model_mode='test',
                )
                metric = F.softmax(out, dim=1)
                if with_labels:
                    metric, labels = accelerator.gather_for_metrics((metric, labels))
                    label_chunks.append(labels.detach().cpu().numpy())
                else:
                    metric = accelerator.gather_for_metrics(metric)
                # Converting the whole batch at once gives the same flattened
                # order as concatenating its rows one by one.
                energy_chunks.append(metric.detach().cpu().numpy())
        accelerator.wait_for_everyone()

        def _flatten(chunks):
            # Concatenate along the batch axis, then flatten to 1-D.
            if not chunks:
                return np.empty(0)
            return np.concatenate(chunks, axis=0).reshape(-1)

        energy = _flatten(energy_chunks)
        labels_out = _flatten(label_chunks) if with_labels else None
        return energy, labels_out

    @staticmethod
    def _point_adjust_predictions(gt, pred):
        """Return a copy of ``pred`` with point adjustment applied.

        Every maximal contiguous run of ``gt == 1`` that contains at least one
        predicted anomaly is marked as detected over its whole length,
        including a run that starts at index 0. Predictions where ``gt != 1``
        are left unchanged.

        Args:
            gt: 1-D sequence of binary (0/1) ground-truth labels.
            pred: 1-D sequence of binary (0/1) predictions, same length as ``gt``.

        Returns:
            A new numpy array of adjusted predictions; ``pred`` is not modified.
        """
        adjusted = np.array(pred, copy=True)
        n = len(gt)
        start = 0
        while start < n:
            if gt[start] != 1:
                start += 1
                continue
            # Find the end (exclusive) of this anomaly segment.
            end = start
            while end < n and gt[end] == 1:
                end += 1
            if adjusted[start:end].any():
                adjusted[start:end] = 1
            start = end
        return adjusted