# eval_zeroshot.py
import torch

import utils
def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc):
    # switch to evaluate mode
    model.eval()

    total_top1 = 0
    total_images = 0
    all_outputs = []
    all_targets = []

    print('=> encoding captions')
    with torch.no_grad():
        # build one text embedding per class by ensembling over all prompt
        # templates (and over synonyms, when a label is given as a list)
        text_features = []
        for label in labels:
            if isinstance(label, list):
                texts = [t.format(l) for t in templates for l in label]
            else:
                texts = [t.format(label) for t in templates]
            texts = tokenizer(texts).cuda(non_blocking=True)
            texts = texts.view(-1, 77).contiguous()  # 77 = CLIP-style token context length
            class_embeddings = utils.get_model(model).encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            # average the L2-normalized prompt embeddings, then re-normalize
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)
        for images, target in val_loader:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # encode images
            image_features = utils.get_model(model).encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logits_per_image = image_features @ text_features.t()

            if is_acc:
                # measure top-1 accuracy (no loss is recorded here)
                pred = logits_per_image.argmax(dim=1)
                correct = pred.eq(target).sum()
                total_top1 += correct.item()
                total_images += images.size(0)
            else:
                # keep raw logits/targets so the caller can compute other metrics
                all_outputs.append(logits_per_image.cpu())
                all_targets.append(target.cpu())
    if is_acc:
        return 100 * total_top1 / total_images
    else:
        return torch.cat(all_outputs), torch.cat(all_targets)
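

# Hypothetical sketch (an assumption, not shown in this file): the function
# relies on utils.get_model, which presumably just unwraps a
# DistributedDataParallel/DataParallel wrapper so that encode_image and
# encode_text are reachable on the underlying module. Name below is ours.
import torch.nn as nn


def _get_model_sketch(model):
    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
        return model.module
    return model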
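

# Self-contained mini-demo of the prompt-ensembling step used above (an added
# illustration, not part of the original evaluation flow): prompt embeddings
# are L2-normalized, averaged into one class vector, and re-normalized, so
# image_features @ text_features.t() is a cosine similarity.
#
# A hedged call example, assuming `val_loader`, `model`, and `tokenizer` are
# built elsewhere (the tokenizer must return a (len(texts), 77) LongTensor):
#
#     templates = ['a photo of a {}.', 'a blurry photo of a {}.']
#     labels = ['cat', 'dog', ['automobile', 'car']]  # a list groups synonyms
#     acc = validate_zeroshot(val_loader, templates, labels, model, tokenizer,
#                             is_acc=True)
#     print(f'zero-shot top-1 accuracy: {acc:.2f}%')
if __name__ == '__main__':
    torch.manual_seed(0)
    prompt_emb = torch.randn(7, 512)                                 # 7 prompts, 512-d
    prompt_emb = prompt_emb / prompt_emb.norm(dim=-1, keepdim=True)  # unit-norm prompts
    class_emb = prompt_emb.mean(dim=0)                               # ensemble average
    class_emb = class_emb / class_emb.norm()                         # back to unit norm
    print(f'class embedding norm: {class_emb.norm().item():.4f}')   # -> 1.0000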