in trainers/catex.py [0:0]
def test_ood(self, split=None, model_directory=''):
    """A generic OOD testing pipeline.

    Runs the model over the in-distribution (ID) val/test split, then over a
    benchmark-specific list of OOD datasets, printing and optionally saving
    AUROC / AUPR / FPR95 for each.

    NOTE(review): this block was re-indented from a whitespace-mangled dump;
    statement nestings marked NOTE(review) below should be confirmed against
    the original file.

    Args:
        split: which ID split to evaluate ('val' or 'test'); when None,
            falls back to cfg.TEST.SPLIT.
        model_directory: model/output directory; used for saving/restoring
            auxiliary tensors and for the `min_thresh` heuristic.

    Returns:
        (first ID-evaluator metric, auroc, aupr, fpr95) -- note the three OOD
        numbers come from the *last* OOD dataset in the loop, not the average.
    """
    # Local imports keep heavy/optional deps out of module import time.
    from tqdm import tqdm
    import os
    import os.path as osp
    from torch.utils.data import DataLoader
    import numpy as np
    from ood.datasets import CLIPFeatDataset
    from ood.datasets import SCOODDataset, LargeOODDataset, SemanticOODDataset, ClassOODDataset
    from ood.metrics import get_msp_scores, get_measures

    self.set_model_mode("eval")
    self.evaluator.reset()
    if split is None:
        split = self.cfg.TEST.SPLIT

    # ---- ID data loader: pre-cached CLIP features or raw images. ----
    if self.cfg.TRAINER.FEAT_AS_INPUT:
        feat_data_dir = self.dm.dataset.dataset_dir+'/clip_feat'
        if not osp.exists(feat_data_dir):
            self.cache_feat(split=split, is_ood=False)  # cache once, reuse later
        data_loader = DataLoader(
            CLIPFeatDataset(feat_data_dir, self.start_epoch, split='test'),
            batch_size=self.test_loader.batch_size, shuffle=False,
            num_workers=self.test_loader.num_workers, pin_memory=True, drop_last=False,
        )
    else:
        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
        else:
            split = "test" # in case val_loader is None
            data_loader = self.test_loader

    lab2cname = self.dm.dataset.lab2cname
    # OOD dataset lists per benchmark suite.
    ood_cfg = {
        'SCOOD': ['texture', 'svhn', 'cifar', 'tin', 'lsun', 'places365'],
        'LargeOOD': ['inaturalist', 'sun', 'places', 'texture'],
    }
    data_root = osp.abspath(osp.expanduser(self.cfg.DATASET.ROOT))
    # CIFAR-family ID sets are evaluated against SCOOD; others against LargeOOD.
    ood_type = 'SCOOD' if 'cifar' in self.dm.dataset.dataset_name else 'LargeOOD' # LargeOOD, ClassOOD

    # Post-hoc scoring variant (option names prefixed 'apply_') is forwarded
    # to the model; any other OOD_INFER_OPTION value disables it.
    if 'apply_' in self.cfg.TRAINER.OOD_INFER_OPTION:
        posthoc = self.cfg.TRAINER.OOD_INFER_OPTION
    else:
        posthoc = None

    if self.cfg.TRAINER.OOD_PROMPT:
        # Text features of the learned OOD ("negative") prompts.
        if self.cfg.TRAINER.OOD_PROMPT_NUM > 1:
            ood_text_features = torch.stack([self.model.get_text_features(ood_prompt=True, ood_prompt_idx=i) for i in range(self.cfg.TRAINER.OOD_PROMPT_NUM)])
        else:
            ood_text_features = self.model.get_text_features(ood_prompt=True)
        if self.cfg.TRAINER.CATEX.CTX_INIT:
            # Only the 'ensemble_learned' init is supported for ensembling here.
            assert self.cfg.TRAINER.CATEX.CTX_INIT == 'ensemble_learned'
            ood_text_features = self.model.prompt_ensemble(ood_text_features)
        # Lower clamp for the ID-probability rescaling used by the
        # OOD_INFER_INTEGRATE branches below; per-model magic values.
        # NOTE(review): nesting of the next two statements (inside the
        # OOD_PROMPT branch) is reconstructed -- confirm against original.
        min_thresh = 0.51 if any(flag in model_directory for flag in ['/imagenet/', '/imagenet100-MCM-SCTX8-Orth/']) else 0.5
        self.model.text_feature_ensemble = self.model.get_text_features()

    print(f"Evaluate on the *{split}* set")

    # ---- Optional: dump features/labels/prompts for later reuse. ----
    if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res':
        save_dir = f'{model_directory}/restore'
        os.makedirs(save_dir, exist_ok=True)
        with open(f'{save_dir}/lab2cname.json', 'w+') as f:
            json.dump(lab2cname, f, indent=4)
        text_features = self.model.get_text_features()
        torch.save(text_features.cpu(), f'{save_dir}/in_text_features.pt')
        if self.cfg.TRAINER.OOD_PROMPT:
            torch.save(ood_text_features.cpu(), f'{save_dir}/ood_text_features.pt')
        im_feats, im_labels = [], []

    # ---- Optional: merge a previously saved run's classes ("base") with the
    # current ones ("novel") for a class-extension evaluation. ----
    if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
        # Hard-coded restore location of the earlier run.
        resume_dir = 'weights/imagenet100-MCM/CATEX/vit_b16_ep50_-1shots/nctx16_cscTrue_ctpend/seed1/restore'
        resume_image_features = torch.load(f'{resume_dir}/in_image_features.pt').to(self.device)
        resume_image_labels = torch.load(f'{resume_dir}/in_labels.pt').to(self.device)
        resume_text_features = torch.load(f'{resume_dir}/in_text_features.pt').to(self.device)
        if self.cfg.TRAINER.OOD_PROMPT:
            resume_ood_text_features = torch.load(f'{resume_dir}/ood_text_features.pt').to(self.device)
        with open(f'{resume_dir}/lab2cname.json', 'r') as f:
            resume_lab2cname = json.load(f)
        resume_lab2cname = {int(k): v for k, v in resume_lab2cname.items()}
        # Shift resumed labels past the current label range.
        # NOTE(review): the offset equals the number of *resumed* classes, but
        # resumed text features are concatenated AFTER the current ones, so
        # this is only consistent when both label spaces have the same size --
        # confirm intent.
        label_offset = resume_image_labels.max().item() + 1
        resume_image_labels += label_offset
        text_features = self.model.get_text_features()
        merged_text_features = torch.cat((text_features, resume_text_features), dim=0)
        if self.cfg.TRAINER.OOD_PROMPT: # TODO: not implemented
            merged_ood_text_features = torch.cat((ood_text_features, resume_ood_text_features), dim=0)

    # ---- Pass 1: score the ID split. ----
    score_list = []
    base_acc, novel_acc = [], []
    near_ood_flag = []
    all_logits = []
    for batch_idx, batch in enumerate(tqdm(data_loader)):
        input, label = self.parse_batch_test(batch)
        image_features, _, output = \
            self.model(input, return_feat=True, return_norm=True, posthoc=posthoc)
        if self.cfg.TRAINER.OOD_PROMPT:
            # Similarities against the OOD prompts (unscaled logits).
            ood_logits = self.model.get_logits(image_features, ood_text_features, logit_scale=1.)
            # all_logits.append(torch.stack((output, ood_logits), dim=1))
            if self.cfg.TRAINER.OOD_INFER_INTEGRATE:
                # Per-class probability of being ID vs OOD; rescale the ID
                # logits by it, clamped from below so they are not zeroed out.
                id_score = F.softmax(torch.stack((output, ood_logits), dim=1), dim=1)[:, 0, :]
                output *= id_score.clamp(min=min_thresh)
        else:
            ood_logits = None
        if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res':
            im_feats.append(image_features.cpu())
            im_labels.append(label.cpu())
        if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
            # Align resumed features with this batch; relies on both loaders
            # being unshuffled and identically ordered.
            start = data_loader.batch_size * batch_idx
            end = start + input.shape[0]
            merged_image_features = torch.cat((image_features, resume_image_features[start:end]), dim=0)
            label = torch.cat((label, resume_image_labels[start:end]), dim=0)
            output = merged_image_features @ merged_text_features.t()
            acc = output.argmax(dim=1) == label
            # First input.shape[0] rows are the current ("novel") samples.
            base_acc.append(acc[input.shape[0]:].cpu())
            novel_acc.append(acc[:input.shape[0]].cpu())
        if hasattr(self.dm.dataset, 'valid_classes'):
            # if self.cfg.TRAINER.OOD_PROMPT:
            # raise NotImplementedError
            # Mask out invalid classes before MSP scoring.
            output[:, ~self.dm.dataset.valid_classes] = -1.
            scores = get_msp_scores(output[:, self.dm.dataset.valid_classes])
        else:
            scores, ood_flag = get_msp_scores(output, ood_logits, self.cfg.TRAINER.OOD_INFER, ret_near_ood=True)
            near_ood_flag.append(ood_flag)
        score_list.append(scores.detach().cpu().numpy())
        self.evaluator.process(output, label)
    in_scores = np.concatenate(score_list, axis=0)
    results = self.evaluator.evaluate()
    if self.cfg.TRAINER.OOD_PROMPT and len(near_ood_flag) and near_ood_flag[0] is not None:
        # Fraction of ID samples flagged as near-OOD (lower is better).
        print('NearOOD FPR:', torch.cat(near_ood_flag).sum().item() / len(in_scores))
    if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res':
        torch.save(torch.cat(im_feats), f'{save_dir}/in_image_features.pt')
        torch.save(torch.cat(im_labels), f'{save_dir}/in_labels.pt')
        if len(all_logits):
            # Dead in practice: the all_logits append above is commented out.
            torch.save(torch.cat(all_logits), f'{save_dir}/in_logits_all.pt')
    if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
        print(f'Base: {torch.cat(base_acc).float().mean(): .4f}. Novel: {torch.cat(novel_acc).float().mean(): .4f}')

    # ---- Pass 2: score each OOD dataset and compute detection metrics. ----
    auroc_list, aupr_list, fpr95_list = [], [], []
    ood_tpr_list = []
    save_lines = []
    for ood_name in ood_cfg[ood_type]:
        # Resolves e.g. SCOODDataset / LargeOODDataset from the imports above.
        ood_set = eval(f'{ood_type}Dataset')(osp.join(data_root, ood_type), id_name=self.dm.dataset.dataset_name,
            ood_name=ood_name, transform=self.test_loader.dataset.transform)
        if self.cfg.TRAINER.FEAT_AS_INPUT:
            feat_data_dir = f'{data_root}/{ood_type}/clip_feat/{ood_name}'
            if not osp.exists(feat_data_dir):
                self.cache_feat(split='test', is_ood=True)
            ood_loader = DataLoader(
                CLIPFeatDataset(feat_data_dir, epoch=None, split='test'),
                batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False,
                num_workers=data_loader.num_workers, pin_memory=True, drop_last=False,
            )
        else:
            ood_loader = DataLoader(ood_set, batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False, num_workers=data_loader.num_workers,
                drop_last=False, pin_memory=True)
        ood_score_list, sc_labels_list, ood_pred_list = [], [], []
        near_ood_flag = []
        all_logits = []
        for batch_idx, batch in enumerate(tqdm(ood_loader)):
            if self.cfg.TRAINER.FEAT_AS_INPUT:
                images, sc_labels = self.parse_batch_test(batch)
            else:
                images, sc_labels = batch
                images = images.to(self.device)
            image_features, _, output = \
                self.model(images, return_feat=True, return_norm=True, posthoc=posthoc)
            if self.cfg.TRAINER.OOD_PROMPT:
                # Same OOD-prompt rescaling as in the ID loop above.
                ood_logits = self.model.get_logits(image_features, ood_text_features, logit_scale=1.)
                # all_logits.append(torch.stack((output, ood_logits), dim=1))
                if self.cfg.TRAINER.OOD_INFER_INTEGRATE:
                    id_score = F.softmax(torch.stack((output, ood_logits), dim=1), dim=1)[:, 0, :]
                    output *= id_score.clamp(min=min_thresh)
            else:
                ood_logits = None
            if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res':
                # Score OOD samples against the merged (base + novel) classifier.
                output = image_features @ merged_text_features.t()
            if hasattr(self.dm.dataset, 'valid_classes'):
                output[:, ~self.dm.dataset.valid_classes] = -1.
                scores = get_msp_scores(output[:, self.dm.dataset.valid_classes])
            else:
                scores, ood_flag = get_msp_scores(output, ood_logits, self.cfg.TRAINER.OOD_INFER, ret_near_ood=True)
                near_ood_flag.append(ood_flag)
            ood_score_list.append(scores.detach().cpu().numpy())
            sc_labels_list.append(sc_labels.cpu().numpy())
            ood_pred_list.append(output.argmax(dim=1).cpu().numpy())
        ood_scores = np.concatenate(ood_score_list, axis=0)
        sc_labels = np.concatenate(sc_labels_list, axis=0)
        ood_preds = np.concatenate(ood_pred_list, axis=0)  # NOTE: currently unused below
        # SCOOD marks samples that actually belong to ID classes with
        # non-negative labels; fold those into the ID pool for the metrics.
        fake_ood_scores = ood_scores[sc_labels>=0]
        real_ood_scores = ood_scores[sc_labels<0]
        real_in_scores = np.concatenate([in_scores, fake_ood_scores], axis=0)
        if 'cifar' in self.dm.dataset.dataset_name:
            # compatible with SCOOD
            auroc, aupr, fpr95, thresh = get_measures(real_ood_scores, real_in_scores)
        else:
            # compatible with NPOS (scores negated, so higher = more ID)
            auroc, aupr, fpr95, thresh = get_measures(-real_in_scores, -real_ood_scores)
        print('auroc: %.4f, aupr: %.4f, fpr95: %.4f' % (auroc, aupr, fpr95))
        save_lines.append('%10s auroc: %.4f, aupr: %.4f, fpr95: %.4f\n' % (ood_name, auroc, aupr, fpr95))
        auroc_list.append(auroc)
        aupr_list.append(aupr)
        fpr95_list.append(fpr95)
        if self.cfg.TRAINER.OOD_PROMPT and len(near_ood_flag) and near_ood_flag[0] is not None:
            # Fraction of OOD samples flagged as near-OOD (higher is better).
            ood_tpr = torch.cat(near_ood_flag).sum().item() / len(ood_scores)
            print('NearOOD TPR: %.4f' % ood_tpr)
            ood_tpr_list.append(ood_tpr)
        if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res' and len(all_logits):
            # Dead in practice: the all_logits append above is commented out.
            torch.save(torch.cat(all_logits), f'{save_dir}/ood_{ood_name}_logits_all.pt')
    print('\nAverage: auroc: %.4f, aupr: %.4f, fpr95: %.4f' % (np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list)))
    # NOTE(review): the 'nAverage' tag below looks like a dropped backslash
    # from '\nAverage' -- kept as-is to preserve the on-disk output format.
    save_lines.append('%10s auroc: %.4f, aupr: %.4f, fpr95: %.4f\n' % ('nAverage', np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list)))
    if self.cfg.TRAINER.OOD_PROMPT and len(ood_tpr_list) > 1:
        print('Average: OOD-TPR: %.4f' % np.mean(ood_tpr_list))
    if model_directory != '':
        if 'ClassOOD' == ood_type:
            # Flattened (auroc, aupr, fpr95) per OOD set, as percentages.
            res_list = np.stack((auroc_list, aupr_list, fpr95_list), axis=1).reshape(-1,) * 100
            np.savetxt(f'{model_directory}/{ood_type}_results.csv', res_list, fmt='%.2f', delimiter=',')
        save_path = f'{model_directory}/{ood_type}_results.txt'
        with open(save_path, 'w+') as f:
            f.writelines(save_lines)
    # Headline ID metric plus the LAST OOD dataset's detection numbers.
    return list(results.values())[0], auroc, aupr, fpr95