def test_ood()

in trainers/catex.py


    def test_ood(self, split=None, model_directory=''):
        """A generic OOD testing pipeline.

        Runs the model on the in-distribution (ID) test split, then on each
        OOD dataset of the matching benchmark suite, and reports classification
        accuracy together with AUROC / AUPR / FPR95 detection metrics. Results
        are optionally written under model_directory.
        """
        from tqdm import tqdm
        import json
        import os
        import os.path as osp

        import numpy as np
        import torch
        import torch.nn.functional as F
        from torch.utils.data import DataLoader

        from ood.datasets import CLIPFeatDataset
        from ood.datasets import SCOODDataset, LargeOODDataset, SemanticOODDataset, ClassOODDataset
        from ood.metrics import get_msp_scores, get_measures

        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

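        # Fast path: with FEAT_AS_INPUT, evaluate on CLIP features pre-extracted
        # by cache_feat() instead of re-encoding raw test images each run.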
        if self.cfg.TRAINER.FEAT_AS_INPUT:
            feat_data_dir = self.dm.dataset.dataset_dir+'/clip_feat'
            if not osp.exists(feat_data_dir):
                self.cache_feat(split=split, is_ood=False)

            data_loader = DataLoader(
                    CLIPFeatDataset(feat_data_dir, self.start_epoch, split='test'), 
                    batch_size=self.test_loader.batch_size, shuffle=False,
                    num_workers=self.test_loader.num_workers, pin_memory=True, drop_last=False, 
                )
        else:
            if split == "val" and self.val_loader is not None:
                data_loader = self.val_loader
            else:
                split = "test"  # in case val_loader is None
                data_loader = self.test_loader
        lab2cname = self.dm.dataset.lab2cname

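        # OOD benchmark suite per ID dataset family: SC-OOD splits for CIFAR-scale
        # models, large-scale sets (iNaturalist/SUN/Places/Texture) otherwise.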
        ood_cfg = {
            'SCOOD': ['texture', 'svhn', 'cifar', 'tin', 'lsun', 'places365'],
            'LargeOOD': ['inaturalist', 'sun', 'places', 'texture'],
        }
        data_root = osp.abspath(osp.expanduser(self.cfg.DATASET.ROOT))
        ood_type = 'SCOOD' if 'cifar' in self.dm.dataset.dataset_name else 'LargeOOD'  # alternatives: 'LargeOOD', 'ClassOOD'
        if 'apply_' in self.cfg.TRAINER.OOD_INFER_OPTION:
            posthoc = self.cfg.TRAINER.OOD_INFER_OPTION
        else:
            posthoc = None

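        # Text features for the learned negative ("OOD") prompts; with several
        # OOD prompts, one feature set is stacked per prompt.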
        if self.cfg.TRAINER.OOD_PROMPT:
            if self.cfg.TRAINER.OOD_PROMPT_NUM > 1:
                ood_text_features = torch.stack([
                    self.model.get_text_features(ood_prompt=True, ood_prompt_idx=i)
                    for i in range(self.cfg.TRAINER.OOD_PROMPT_NUM)
                ])
            else:
                ood_text_features = self.model.get_text_features(ood_prompt=True)
                if self.cfg.TRAINER.CATEX.CTX_INIT:
                    assert self.cfg.TRAINER.CATEX.CTX_INIT == 'ensemble_learned'
                    ood_text_features = self.model.prompt_ensemble(ood_text_features)
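            # Floor for the ID-score rescaling below (apparently tuned per run).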
            min_thresh = 0.51 if any(flag in model_directory for flag in ['/imagenet/', '/imagenet100-MCM-SCTX8-Orth/']) else 0.5

        self.model.text_feature_ensemble = self.model.get_text_features()

        print(f"Evaluate on the *{split}* set")

        if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res':
            save_dir = f'{model_directory}/restore'
            os.makedirs(save_dir, exist_ok=True)
            with open(f'{save_dir}/lab2cname.json', 'w+') as f:
                json.dump(lab2cname, f, indent=4)

            text_features = self.model.get_text_features()
            torch.save(text_features.cpu(), f'{save_dir}/in_text_features.pt')
            if self.cfg.TRAINER.OOD_PROMPT:
                torch.save(ood_text_features.cpu(), f'{save_dir}/ood_text_features.pt')

            im_feats, im_labels = [], []

        if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
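            # NOTE: hard-coded path to base-class features/labels cached by an
            # earlier 'save_res' run.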
            resume_dir = 'weights/imagenet100-MCM/CATEX/vit_b16_ep50_-1shots/nctx16_cscTrue_ctpend/seed1/restore'
            resume_image_features = torch.load(f'{resume_dir}/in_image_features.pt').to(self.device)
            resume_image_labels = torch.load(f'{resume_dir}/in_labels.pt').to(self.device)
            resume_text_features = torch.load(f'{resume_dir}/in_text_features.pt').to(self.device)
            if self.cfg.TRAINER.OOD_PROMPT:
                resume_ood_text_features = torch.load(f'{resume_dir}/ood_text_features.pt').to(self.device)
            with open(f'{resume_dir}/lab2cname.json', 'r') as f:
                resume_lab2cname = json.load(f)
            resume_lab2cname = {int(k): v for k, v in resume_lab2cname.items()}

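            # Shift resumed (base) labels past the current class ids; using
            # max()+1 as the offset assumes both splits share the same class count.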
            label_offset = resume_image_labels.max().item() + 1
            resume_image_labels += label_offset

            text_features = self.model.get_text_features()
            merged_text_features = torch.cat((text_features, resume_text_features), dim=0)
            if self.cfg.TRAINER.OOD_PROMPT:  # TODO: not implemented
                merged_ood_text_features = torch.cat((ood_text_features, resume_ood_text_features), dim=0)

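        # Pass 1: score the in-distribution test split.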
        score_list = []
        base_acc, novel_acc = [], []
        near_ood_flag = []
        all_logits = []
        for batch_idx, batch in enumerate(tqdm(data_loader)):
            input, label = self.parse_batch_test(batch)
            image_features, _, output = self.model(
                input, return_feat=True, return_norm=True, posthoc=posthoc)
            
            if self.cfg.TRAINER.OOD_PROMPT:
                ood_logits = self.model.get_logits(image_features, ood_text_features, logit_scale=1.)
                # all_logits.append(torch.stack((output, ood_logits), dim=1))  # uncomment to also dump raw ID/OOD logits under 'save_res'

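                # Softmax over the stacked (ID, OOD) logit pair gives, per class,
                # the probability the sample matches the ID prompt; the ID logits
                # are rescaled by this score, floored at min_thresh.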
                if self.cfg.TRAINER.OOD_INFER_INTEGRATE:
                    id_score = F.softmax(torch.stack((output, ood_logits), dim=1), dim=1)[:, 0, :]
                    output *= id_score.clamp(min=min_thresh)
            else:
                ood_logits = None

            if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res':
                im_feats.append(image_features.cpu())
                im_labels.append(label.cpu())

            if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
                start = data_loader.batch_size * batch_idx
                end = start + input.shape[0]
                merged_image_features = torch.cat((image_features, resume_image_features[start:end]), dim=0)
                label = torch.cat((label, resume_image_labels[start:end]), dim=0)

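                # Re-score against the merged class set, tracking accuracy
                # separately for resumed (base) and current (novel) samples.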
                output = merged_image_features @ merged_text_features.t()
                acc = output.argmax(dim=1) == label
                base_acc.append(acc[input.shape[0]:].cpu())
                novel_acc.append(acc[:input.shape[0]].cpu())

            if hasattr(self.dm.dataset, 'valid_classes'):
                # if self.cfg.TRAINER.OOD_PROMPT:
                #     raise NotImplementedError
                output[:, ~self.dm.dataset.valid_classes] = -1.
                scores = get_msp_scores(output[:, self.dm.dataset.valid_classes])
            else:
                scores, ood_flag = get_msp_scores(output, ood_logits, self.cfg.TRAINER.OOD_INFER, ret_near_ood=True)
                near_ood_flag.append(ood_flag)
            score_list.append(scores.detach().cpu().numpy())
            self.evaluator.process(output, label)

        in_scores = np.concatenate(score_list, axis=0)
        results = self.evaluator.evaluate()
        if self.cfg.TRAINER.OOD_PROMPT and len(near_ood_flag) and near_ood_flag[0] is not None:
            print('NearOOD FPR:', torch.cat(near_ood_flag).sum().item() / len(in_scores))

        if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res':
            torch.save(torch.cat(im_feats), f'{save_dir}/in_image_features.pt')
            torch.save(torch.cat(im_labels), f'{save_dir}/in_labels.pt')
            if len(all_logits):
                torch.save(torch.cat(all_logits), f'{save_dir}/in_logits_all.pt')
        
        if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
            print(f'Base: {torch.cat(base_acc).float().mean(): .4f}. Novel: {torch.cat(novel_acc).float().mean(): .4f}')

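        # Pass 2: score each OOD dataset and compute detection metrics against
        # the ID scores gathered above.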
        auroc_list, aupr_list, fpr95_list = [], [], []
        ood_tpr_list = []
        save_lines = []
        for ood_name in ood_cfg[ood_type]:
            ood_dataset_cls = eval(f'{ood_type}Dataset')  # SCOODDataset or LargeOODDataset, imported above
            ood_set = ood_dataset_cls(osp.join(data_root, ood_type), id_name=self.dm.dataset.dataset_name,
                                      ood_name=ood_name, transform=self.test_loader.dataset.transform)
            if self.cfg.TRAINER.FEAT_AS_INPUT:
                feat_data_dir = f'{data_root}/{ood_type}/clip_feat/{ood_name}'
                if not osp.exists(feat_data_dir):
                    self.cache_feat(split='test', is_ood=True)

                ood_loader = DataLoader(
                        CLIPFeatDataset(feat_data_dir, epoch=None, split='test'), 
                        batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False,
                        num_workers=data_loader.num_workers, pin_memory=True, drop_last=False, 
                    )
            else:
                ood_loader = DataLoader(ood_set, batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False, num_workers=data_loader.num_workers,
                                            drop_last=False, pin_memory=True)
            
            ood_score_list, sc_labels_list, ood_pred_list = [], [], []
            near_ood_flag = []
            all_logits = []
            for batch_idx, batch in enumerate(tqdm(ood_loader)):
                if self.cfg.TRAINER.FEAT_AS_INPUT:
                    images, sc_labels = self.parse_batch_test(batch)
                else:
                    images, sc_labels = batch
                    images = images.to(self.device)

                image_features, _, output = self.model(
                    images, return_feat=True, return_norm=True, posthoc=posthoc)
            
                if self.cfg.TRAINER.OOD_PROMPT:
                    ood_logits = self.model.get_logits(image_features, ood_text_features, logit_scale=1.)
                    # all_logits.append(torch.stack((output, ood_logits), dim=1))  # uncomment to also dump raw ID/OOD logits under 'save_res'

                    if self.cfg.TRAINER.OOD_INFER_INTEGRATE:
                        id_score = F.softmax(torch.stack((output, ood_logits), dim=1), dim=1)[:, 0, :]
                        output *= id_score.clamp(min=min_thresh)
                else:
                    ood_logits = None

                if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
                    output = image_features @ merged_text_features.t()

                if hasattr(self.dm.dataset, 'valid_classes'):
                    output[:, ~self.dm.dataset.valid_classes] = -1.
                    scores = get_msp_scores(output[:, self.dm.dataset.valid_classes])
                else:
                    scores, ood_flag = get_msp_scores(output, ood_logits, self.cfg.TRAINER.OOD_INFER, ret_near_ood=True)
                    near_ood_flag.append(ood_flag)
                ood_score_list.append(scores.detach().cpu().numpy())
                sc_labels_list.append(sc_labels.cpu().numpy())
                ood_pred_list.append(output.argmax(dim=1).cpu().numpy())
            ood_scores = np.concatenate(ood_score_list, axis=0)
            sc_labels = np.concatenate(sc_labels_list, axis=0)
            ood_preds = np.concatenate(ood_pred_list, axis=0)
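            # SC-OOD convention: sc_labels >= 0 marks OOD-set samples that actually
            # depict ID classes, so their scores are pooled with the ID scores.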
            fake_ood_scores = ood_scores[sc_labels>=0]
            real_ood_scores = ood_scores[sc_labels<0]
            real_in_scores = np.concatenate([in_scores, fake_ood_scores], axis=0)

            if 'cifar' in self.dm.dataset.dataset_name:  
                # compatible with SCOOD
                auroc, aupr, fpr95, thresh = get_measures(real_ood_scores, real_in_scores)
            else:  
                # compatible with NPOS
                auroc, aupr, fpr95, thresh = get_measures(-real_in_scores, -real_ood_scores)
            print('auroc: %.4f, aupr: %.4f, fpr95: %.4f' % (auroc, aupr, fpr95))

            save_lines.append('%10s auroc: %.4f, aupr: %.4f, fpr95: %.4f\n' % (ood_name, auroc, aupr, fpr95))
            auroc_list.append(auroc)
            aupr_list.append(aupr)
            fpr95_list.append(fpr95)
            if self.cfg.TRAINER.OOD_PROMPT and len(near_ood_flag) and near_ood_flag[0] is not None:
                ood_tpr = torch.cat(near_ood_flag).sum().item() / len(ood_scores)
                print('NearOOD TPR: %.4f' % ood_tpr)
                ood_tpr_list.append(ood_tpr)
            
            if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res' and len(all_logits):
                torch.save(torch.cat(all_logits), f'{save_dir}/ood_{ood_name}_logits_all.pt')

        print('\nAverage: auroc: %.4f, aupr: %.4f, fpr95: %.4f' % (np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list)))
        save_lines.append('%10s auroc: %.4f, aupr: %.4f, fpr95: %.4f\n' % ('Average', np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list)))
        if self.cfg.TRAINER.OOD_PROMPT and len(ood_tpr_list) > 1:
            print('Average: OOD-TPR: %.4f' % np.mean(ood_tpr_list))

        if model_directory != '':
            if ood_type == 'ClassOOD':
                res_list = np.stack((auroc_list, aupr_list, fpr95_list), axis=1).reshape(-1,) * 100
                np.savetxt(f'{model_directory}/{ood_type}_results.csv', res_list, fmt='%.2f', delimiter=',')
            save_path = f'{model_directory}/{ood_type}_results.txt'
            with open(save_path, 'w+') as f:
                f.writelines(save_lines)

        # ID metric (accuracy) plus detection metrics averaged over the OOD suite
        return list(results.values())[0], np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list)
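
A minimal calling sketch for context. This is an assumption, not part of this file: CATEX builds on the Dassl/CoOp trainer interface, so setup_cfg and build_trainer below are hypothetical stand-ins for the repo's actual entry-point helpers, while load_model is the standard Dassl checkpoint loader.

    # Hypothetical driver (a sketch; setup_cfg/build_trainer are illustrative names)
    cfg = setup_cfg(args)                    # yacs config carrying the TRAINER.* flags read above
    trainer = build_trainer(cfg)             # constructs the CATEX trainer from trainers/catex.py
    trainer.load_model(args.model_dir, epoch=args.load_epoch)

    # First value is the evaluator's leading metric (accuracy in Dassl);
    # the rest are the OOD detection metrics returned by test_ood().
    acc, auroc, aupr, fpr95 = trainer.test_ood(split='test', model_directory=args.model_dir)

When model_directory is non-empty, the per-dataset metrics are also written to {model_directory}/{ood_type}_results.txt, as in the listing above.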