def test()

in data_augmentation/test.py [0:0]


def test(loader, model, args, per_img_dict, disable_augerino=True, compute_per_sample=False):
    """Evaluate ``model`` on ``loader`` and return top-1, top-5 and per-class accuracy.

    Args:
        loader: Iterable yielding ``(images, target, pathes)`` batches.
        model: Network to evaluate. It is switched to eval mode. When
            ``args.augerino`` is set, its ``disabled`` attribute is written
            (on ``model.module`` for a ``DistributedDataParallel`` wrapper,
            on ``model`` itself for an ``AugAveragedModel``).
        args: Configuration object; this function reads ``num_classes``,
            ``augerino``, ``gpu`` and ``inv_per_class``.
        per_img_dict: Mapping filled in place when ``compute_per_sample`` is
            true: ``pathes[j]`` -> 0-dim float tensor, ``1.0`` if the top-1
            prediction for that sample equals its target, else ``0.0``.
            Not touched otherwise.
        disable_augerino: Value written to the model's ``disabled`` flag
            (only when ``args.augerino`` is set).
        compute_per_sample: Whether to record per-sample top-1 correctness
            into ``per_img_dict``.

    Returns:
        Tuple ``(top1, top5, per_class)``: ``top1.avg.item()``,
        ``top5.avg.item()`` and the tensor
        ``acc_per_class / count_per_class`` of length ``args.num_classes``.

    Note:
        The model is left in eval mode and the ``disabled`` flag is not
        restored on exit.
    """
    # switch to evaluate mode
    model.eval()

    top1 = AverageMeter("Acc@1", ":6.4f")
    top5 = AverageMeter("Acc@5", ":6.4f")
    # Running per-class sums, kept on the CPU.
    # NOTE(review): assumes acc1_per_class returns CPU tensors of length
    # args.num_classes -- confirm against its definition.
    acc_per_class = torch.zeros(args.num_classes)
    count_per_class = torch.zeros(args.num_classes)

    if args.augerino:
        disabled = bool(disable_augerino)
        # Only these two model types carry the flag; any other model is left
        # untouched, but the message is still printed.
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model.module.disabled = disabled
        elif isinstance(model, AugAveragedModel):
            model.disabled = disabled
        print("Disabling Augerino" if disabled else "Enabling Augerino")

    with torch.no_grad():
        for images, target, pathes in loader:
            bs = images.size(0)
            # Images are moved only when an explicit GPU is configured, while
            # targets are moved whenever CUDA is available.
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output; in the per-class Augerino setup the model also
            # receives the targets
            if args.augerino and args.inv_per_class:
                output = model(images, target)
            else:
                output = model(images)

            # measure accuracy
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            if compute_per_sample:
                # Top-1 prediction per sample, compared against the target.
                _, pred = output.topk(1, 1, True, True)
                pred = pred.t()
                correct = pred.eq(target.view(1, -1).expand_as(pred))
                correct = correct.float().squeeze(0)
                for j in range(bs):
                    per_img_dict[pathes[j]] = correct[j]

            acc1_class, count_class = acc1_per_class(args, output, target)

            # NOTE(review): dividing by the batch size suggests `accuracy`
            # returns correct-prediction counts rather than percentages --
            # confirm against its definition.
            acc1 = acc1 / float(bs) * 100.0
            acc5 = acc5 / float(bs) * 100.0

            top1.update(acc1[0], bs)
            top5.update(acc5[0], bs)

            acc_per_class += acc1_class * 100.0
            count_per_class += count_class

    # NOTE(review): `.item()` requires AverageMeter.avg to be a tensor here,
    # and classes with a zero count divide by zero in the per-class result --
    # confirm both are acceptable to callers.
    return top1.avg.item(), top5.avg.item(), acc_per_class / count_per_class