in trainers/catex.py [0:0]
def cache_feat(self, split='train', is_ood=True, ood_names=None):
    """Encode images with ``self.model.image_encoder`` and cache the results to disk.

    Two modes, selected by ``is_ood``:

    * ``is_ood=True``: for each name in ``ood_names``, build a ``LargeOODDataset``
      rooted at ``<cfg.DATASET.ROOT>/LargeOOD`` (using the test loader's transform)
      and write ``test_image_features.pt``, ``test_labels.pt`` and
      ``test_paths.txt`` to ``<cfg.DATASET.ROOT>/LargeOOD/clip_feat/<ood_name>/``.
      ``split`` is ignored in this mode.
    * ``is_ood=False``: encode the in-distribution data. ``split == 'train'``
      iterates ``self.train_loader_x`` once per epoch for epochs
      ``self.start_epoch .. self.max_epoch - 1``; any other ``split`` iterates
      ``self.test_loader`` exactly once (epoch 0). Each pass writes
      ``ep{epoch}_{split}_image_features.pt``, ``ep{epoch}_{split}_labels.pt`` and
      ``ep{epoch}_{split}_paths.txt`` to ``<self.dm.dataset.dataset_dir>/clip_feat/``.

    Features and labels are saved as float16 CPU tensors (concatenated over
    batches); paths are written one per line in loader order.

    Args:
        split: 'train' to cache the training loader; anything else caches the
            test loader.
        is_ood: if True, cache the large-scale OOD datasets instead of ``split``.
        ood_names: OOD dataset names to cache when ``is_ood`` is True. Defaults to
            ``['inaturalist', 'sun', 'places', 'texture']``.
    """
    if ood_names is None:
        ood_names = ['inaturalist', 'sun', 'places', 'texture']

    self.set_model_mode("eval")
    self.evaluator.reset()

    def _encode(images):
        # Encode one batch; no_grad avoids building an autograd graph that would
        # only be thrown away (the original detached the output afterwards).
        with torch.no_grad():
            feats = self.model.image_encoder(images.to(self.device).type(self.model.dtype))
        return feats.detach().cpu()

    def _dump(save_dir, prefix, features, labels, paths):
        # Write the three cache files sharing ``prefix`` into ``save_dir``.
        # NOTE(review): labels are stored as float16 (kept for compatibility with
        # existing readers); float16 represents integers exactly only up to 2048,
        # so this assumes label values stay below that — confirm per dataset.
        torch.save(torch.cat(features).half(), f'{save_dir}/{prefix}image_features.pt')
        torch.save(torch.cat(labels).half(), f'{save_dir}/{prefix}labels.pt')
        with open(f'{save_dir}/{prefix}paths.txt', 'w') as f:
            f.writelines(p + '\n' for p in paths)

    if is_ood:
        from ood.datasets import LargeOODDataset
        from torch.utils.data import DataLoader

        data_root = osp.join(osp.abspath(osp.expanduser(self.cfg.DATASET.ROOT)), 'LargeOOD')
        for ood_name in ood_names:
            ood_set = LargeOODDataset(data_root, id_name=self.dm.dataset.dataset_name,
                                      ood_name=ood_name, transform=self.test_loader.dataset.transform)
            # shuffle=False keeps batches in ``ood_set.samples`` order, which is what
            # allows the image paths below to be recovered by position.
            ood_loader = DataLoader(ood_set, batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE,
                                    shuffle=False, num_workers=self.test_loader.num_workers,
                                    drop_last=False, pin_memory=True)
            save_dir = f'{data_root}/clip_feat/{ood_name}'
            os.makedirs(save_dir, exist_ok=True)
            features, labels, paths = [], [], []
            for images, label in tqdm(ood_loader, desc='Caching image features'):
                offset = len(paths)  # index of this batch's first sample in ood_set
                features.append(_encode(images))
                labels.append(label.cpu())
                paths.extend(ood_set.samples[offset + i][0] for i in range(len(images)))
            _dump(save_dir, 'test_', features, labels, paths)
        return

    if split == 'train':
        data_loader = self.train_loader_x
        # One cached pass per epoch, resuming from ``self.start_epoch``
        # (presumably so each pass captures a different random train-time
        # augmentation — confirm against the loader's transforms).
        start_epoch, max_epoch = self.start_epoch, self.max_epoch
    else:
        data_loader = self.test_loader
        # The test split is encoded exactly once. Start at 0 rather than
        # ``self.start_epoch``: after a resume that value is >= 1 and
        # ``range(start_epoch, 1)`` would silently cache nothing.
        start_epoch, max_epoch = 0, 1

    save_dir = f'{self.dm.dataset.dataset_dir}/clip_feat'
    os.makedirs(save_dir, exist_ok=True)
    # ``self.epoch`` is the loop variable, so it keeps the last cached epoch afterwards.
    for self.epoch in range(start_epoch, max_epoch):
        features, labels, paths = [], [], []
        desc = f"Caching image features: {split} {self.epoch+1}/{max_epoch}: "
        for batch in tqdm(data_loader, desc=desc):
            features.append(_encode(batch["img"]))
            labels.append(batch["label"].cpu())
            paths.extend(batch["impath"])
        _dump(save_dir, f'ep{self.epoch}_{split}_', features, labels, paths)