in eval_zeroshot.py [0:0]
def main(args):
    # optionally resume from a checkpoint (takes precedence over autoresume)
    if args.resume:
        ckpt_path = args.resume
    elif os.path.isfile(os.path.join(args.output_dir, 'checkpoint_best.pt')):
        ckpt_path = os.path.join(args.output_dir, 'checkpoint_best.pt')
    else:
        raise Exception('no checkpoint found')

    ckpt = torch.load(ckpt_path, map_location='cpu')
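    # strip the 'module.' prefix added by DistributedDataParallel so the weights load into a single-GPU model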
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v
    # create model
    old_args = ckpt['args']
    print("=> creating model: {}".format(old_args.model))
    model = getattr(models, old_args.model)(rand_embed=False,
        ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim)
    model.cuda()
    model.load_state_dict(state_dict, strict=True)
    print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
    cudnn.benchmark = True

    cwd = os.path.dirname(os.path.realpath(__file__))
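    # dataset_catalog.json lists the downstream datasets; templates.json and labels.json hold the prompt templates and class names used to build the zero-shot classifiers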
    with open(os.path.join(cwd, 'dataset_catalog.json')) as f:
        catalog = json.load(f)

    with open(os.path.join(cwd, 'templates.json')) as f:
        all_templates = json.load(f)

    with open(os.path.join(cwd, 'labels.json')) as f:
        all_labels = json.load(f)
    # Data loading code
    print("=> creating dataset")
    tokenizer = SimpleTokenizer()
    val_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        lambda x: x.convert('RGB'),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
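    # evaluate zero-shot transfer on every downstream dataset listed in the catalog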
    results = []
    for d in catalog:
        print('Evaluating {}'.format(d))
        val_dataset = datasets.get_downstream_dataset(catalog, name=d, is_train=False, transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True, drop_last=False)

        templates = all_templates[d]
        labels = all_labels[d]
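        # the datasets listed below are not scored with plain top-1 accuracy, so for them
        # validate_zeroshot returns the raw outputs and targets for a dataset-specific metric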
        is_acc = d not in ['aircraft', 'pets', 'caltech101', 'flowers', 'kinetics700_frames', 'hateful_memes']
        acc_or_outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc)

        if d in ['aircraft', 'pets', 'caltech101', 'flowers']:
            metric = mean_per_class(*acc_or_outputs)
        elif d == 'kinetics700_frames':
            top1, top5 = accuracy(*acc_or_outputs, topk=(1, 5))
            metric = (top1 + top5) / 2
            metric = metric.item()
        elif d == 'hateful_memes':
            metric = roc_auc(*acc_or_outputs)
        else:
            metric = acc_or_outputs

        results.append(metric)
        print('metric:', metric)

    print('all results:')
    for x in results:
        print('{:.1f}'.format(x))
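validate_zeroshot itself is not part of this excerpt. The sketch below shows one common way such a helper is written for CLIP-style zero-shot evaluation; it is only illustrative and assumes the model exposes encode_image/encode_text, that the tokenizer turns a list of prompt strings into a batch of token tensors, and that each entry in templates contains a '{}' placeholder for the class name. The actual implementation in this repository may differ.

@torch.no_grad()
def validate_zeroshot_sketch(val_loader, templates, labels, model, tokenizer, is_acc):
    model.eval()

    # build one text embedding per class by averaging the embeddings of all prompt templates
    classifier = []
    for label in labels:
        texts = tokenizer([t.format(label) for t in templates]).cuda()  # assumed tokenizer behavior
        class_emb = model.encode_text(texts)                            # assumed CLIP-style API
        class_emb = class_emb / class_emb.norm(dim=-1, keepdim=True)
        class_emb = class_emb.mean(dim=0)
        classifier.append(class_emb / class_emb.norm())
    classifier = torch.stack(classifier, dim=1)  # (embed_dim, num_classes)

    correct, total, all_outputs, all_targets = 0, 0, [], []
    for images, target in val_loader:
        images, target = images.cuda(), target.cuda()
        image_emb = model.encode_image(images)                          # assumed CLIP-style API
        image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
        logits = image_emb @ classifier  # cosine similarity to each class embedding
        if is_acc:
            correct += (logits.argmax(dim=-1) == target).sum().item()
            total += target.size(0)
        else:
            all_outputs.append(logits.cpu())
            all_targets.append(target.cpu())

    if is_acc:
        return 100.0 * correct / total
    return torch.cat(all_outputs), torch.cat(all_targets)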