def cache_feat()

in trainers/catex.py [0:0]


    def cache_feat(self, split='train', is_ood=True):
        """A generic OOD testing pipeline."""

        self.set_model_mode("eval")
        self.evaluator.reset()
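        # Select the in-distribution data loader; the train split is re-cached every epoch,
        # the test split only once (these loaders are only used when is_ood=False).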
        if split == 'train':
            data_loader = self.train_loader_x
            max_epoch = self.max_epoch
        else:
            data_loader = self.test_loader
            max_epoch = 1

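        # OOD branch: extract and cache features for each large-scale OOD dataset.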
        if is_ood:
            from ood.datasets import LargeOODDataset
            from torch.utils.data import DataLoader
            data_root = osp.join(osp.abspath(osp.expanduser(self.cfg.DATASET.ROOT)), 'LargeOOD')
            for ood_name in ['inaturalist', 'sun', 'places', 'texture']:
                ood_set = LargeOODDataset(data_root, id_name=self.dm.dataset.dataset_name, 
                                            ood_name=ood_name, transform=self.test_loader.dataset.transform)
                ood_loader = DataLoader(ood_set, batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False,
                                        num_workers=self.test_loader.num_workers, drop_last=False, pin_memory=True)

                save_dir = f'{data_root}/clip_feat/{ood_name}'
                os.makedirs(save_dir, exist_ok=True)

                features, labels, paths = [], [], []
                cnt = 0
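                # cnt tracks the running offset into ood_set.samples so each batch's image paths can be recorded.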
                for input, label in tqdm(ood_loader, desc='Caching image features'):
                    input = input.to(self.device)
                    label = label.to(self.device)

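                    # Encode images with the CLIP image encoder; detach since no gradients are needed here.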
                    image_features = self.model.image_encoder(input.type(self.model.dtype)).detach()

                    features.append(image_features.cpu())
                    labels.append(label.cpu())
                    for i in range(len(input)):
                        paths.append(ood_set.samples[cnt+i][0])
                    cnt += len(input)

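                # Persist features and labels as float16, plus the original image paths.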
                torch.save(torch.cat(features).half(), f'{save_dir}/test_image_features.pt')
                torch.save(torch.cat(labels).half(), f'{save_dir}/test_labels.pt')
                with open(f'{save_dir}/test_paths.txt', 'w+') as f:
                    f.writelines([p + '\n' for p in paths])
        else:
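            # In-distribution branch: cache features for the requested split, writing one file set per epoch.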
            save_dir = f'{self.dm.dataset.dataset_dir}/clip_feat'
            os.makedirs(save_dir, exist_ok=True)

            for self.epoch in range(self.start_epoch, max_epoch):
                features, labels, paths = [], [], []
                for batch_idx, batch in enumerate(tqdm(data_loader, desc=f"Caching image features: {split} {self.epoch+1}/{max_epoch}: ")):
                    input = batch["img"].to(self.device)
                    label = batch["label"].to(self.device)

                    image_features = self.model.image_encoder(input.type(self.model.dtype)).detach()

                    features.append(image_features.cpu())
                    labels.append(label.cpu())
                    paths.extend(batch["impath"])

                torch.save(torch.cat(features).half(), f'{save_dir}/ep{self.epoch}_{split}_image_features.pt')
                torch.save(torch.cat(labels).half(), f'{save_dir}/ep{self.epoch}_{split}_labels.pt')
                with open(f'{save_dir}/ep{self.epoch}_{split}_paths.txt', 'w+') as f:
                    f.writelines([p + '\n' for p in paths])
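
For downstream use, the cached tensors can be read straight back with torch.load. The helper below is a minimal, hypothetical sketch (load_cached_feat is not part of catex.py); it simply mirrors the file naming used by the saving code above for the in-distribution case.

    import torch

    def load_cached_feat(save_dir, split='train', epoch=0):
        """Load one epoch's cached features, labels, and paths written by cache_feat()."""
        features = torch.load(f'{save_dir}/ep{epoch}_{split}_image_features.pt')  # (N, D) float16
        labels = torch.load(f'{save_dir}/ep{epoch}_{split}_labels.pt').long()     # stored as float16; cast back to ints
        with open(f'{save_dir}/ep{epoch}_{split}_paths.txt') as f:
            paths = [line.strip() for line in f]
        assert len(features) == len(labels) == len(paths)
        return features, labels, paths

For the OOD caches, the same pattern applies with the 'test_image_features.pt', 'test_labels.pt', and 'test_paths.txt' files under each '{data_root}/clip_feat/{ood_name}' directory.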