in data_augmentation/test.py [0:0]
def test(loader, model, args, per_img_dict, disable_augerino=True, compute_per_sample=False):
    """Evaluate `model` on `loader` and return aggregate accuracy statistics.

    Args:
        loader: iterable yielding ``(images, target, pathes)`` batches; the third
            element is a per-sample identifier (presumably a file path) used as
            the key into ``per_img_dict`` — TODO confirm against the dataset.
        model: network to evaluate; may be wrapped in
            ``nn.parallel.DistributedDataParallel`` or be an ``AugAveragedModel``.
        args: namespace read for ``num_classes``, ``augerino``, ``gpu``, and
            ``inv_per_class``.
        per_img_dict: dict mutated in place (side effect) with a 0/1 float
            correctness entry per sample when ``compute_per_sample`` is True.
        disable_augerino: when ``args.augerino`` is set, toggles the model's
            ``disabled`` flag off/on before evaluation.
        compute_per_sample: if True, also record top-1 correctness per image.

    Returns:
        Tuple ``(top1_avg, top5_avg, per_class_acc)`` where the first two are
        Python floats from the running meters and the third is a tensor of
        per-class top-1 accuracy.
        NOTE(review): ``acc_per_class/count_per_class`` divides elementwise —
        any class with zero samples in `loader` yields NaN/inf; confirm callers
        tolerate this.
    """
    # switch to evaluate mode
    model.eval()
    losses = []  # NOTE(review): accumulated nowhere below — appears to be dead.
    top1 = AverageMeter("Acc@1", ":6.4f")
    top5 = AverageMeter("Acc@5", ":6.4f")
    count = 0  # NOTE(review): incremented per batch but never read afterwards.
    # Per-class running sums: numerator (correct * 100) and sample counts.
    acc_per_class = torch.zeros(args.num_classes)
    count_per_class = torch.zeros(args.num_classes)
    if args.augerino:
        # Flip the Augerino augmentation-averaging flag before evaluating.
        # DDP wraps the real model in `.module`, so the flag lives one level down.
        if disable_augerino:
            if isinstance(model, nn.parallel.DistributedDataParallel):
                model.module.disabled = True
            elif isinstance(model, AugAveragedModel):
                model.disabled = True
            print("Disabling Augerino")
        else:
            if isinstance(model, nn.parallel.DistributedDataParallel):
                model.module.disabled = False
            elif isinstance(model, AugAveragedModel):
                model.disabled = False
            print("Enabling Augerino")
    with torch.no_grad():  # inference only — no autograd graph needed
        for i, (images, target, pathes) in enumerate(loader):
            bs = images.size(0)
            count+= bs
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                # NOTE(review): runs even when args.gpu is None — target then
                # goes to the default CUDA device while images stay put; verify
                # this is intended for the single-GPU/CPU path.
                target = target.cuda(args.gpu, non_blocking=True)
            # compute output
            if args.augerino and args.inv_per_class:
                # Per-class-invariance variant: model needs labels at eval time.
                output = model(images,target)
            else:
                output = model(images)
            # measure accuracy
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            if compute_per_sample:
                # Recompute top-1 predictions to record a per-image 0/1
                # correctness value, keyed by the sample's path.
                _, pred = output.topk(1, 1, True, True)
                pred = pred.t()
                correct = pred.eq(target.view(1, -1).expand_as(pred))
                correct=correct.float().squeeze(0)
                for j in range(bs):
                    per_img_dict[pathes[j]] = correct[j]
            acc1_class, count_class = acc1_per_class(args, output, target)
            # NOTE(review): assumes `accuracy` returns raw correct counts, so
            # dividing by batch size and scaling by 100 yields a percentage.
            # If `accuracy` already returns percentages (the common torchvision
            # idiom), this double-scales — confirm against its definition.
            acc1 = acc1 / float(bs) * 100.0
            acc5 = acc5 / float(bs) * 100.0
            top1.update(acc1[0], bs)
            top5.update(acc5[0], bs)
            acc_per_class += acc1_class * 100.0
            count_per_class += count_class
    # Meter averages are 0-dim tensors here, hence .item(); the third value is
    # an elementwise per-class percentage (see NaN caveat in the docstring).
    return top1.avg.item(), top5.avg.item(), acc_per_class/count_per_class