def validate()

in data_augmentation/my_training.py [0:0]


def _set_augerino_disabled(model, disabled):
    """Set the ``disabled`` flag on the model, unwrapping DistributedDataParallel.

    Models that are neither a ``DistributedDataParallel`` wrapper nor an
    ``AugAveragedModel`` are left untouched.
    """
    if isinstance(model, nn.parallel.DistributedDataParallel):
        model.module.disabled = disabled
    elif isinstance(model, AugAveragedModel):
        model.disabled = disabled


def _as_scalar(value):
    """Return ``value.item()`` when ``value`` has an ``item`` method, else ``value``."""
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` and report loss and top-1/top-5 accuracy.

    The model is switched to eval mode and the whole loop runs under
    ``torch.no_grad()``. When ``args.augerino`` and ``args.disable_at_valid``
    are both set, the model's ``disabled`` flag is set to ``True`` for the
    duration of the evaluation and reset to ``False`` afterwards, even if the
    evaluation raises.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches; must
            support ``len()``.
        model: Model to evaluate. Called as ``model(images, target)`` when
            ``args.augerino`` and ``args.inv_per_class`` are set, otherwise as
            ``model(images)``.
        criterion: Loss callable. Called as
            ``criterion(output, target, model, args, reg=args.aug_reg)`` when
            ``args.augerino`` is set, otherwise ``criterion(output, target)``.
        args: Namespace providing ``augerino``, ``disable_at_valid``,
            ``inv_per_class``, ``aug_reg`` and ``print_freq``.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)``. The accuracy averages are
        converted to Python scalars via ``.item()`` when they expose it.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.4f")
    top5 = AverageMeter("Acc@5", ":6.4f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    toggle_augerino = args.augerino and args.disable_at_valid
    if toggle_augerino:
        _set_augerino_disabled(model, True)
        print("Disabling Augerino")

    try:
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(val_loader):
                batch_size = images.size(0)

                if torch.cuda.is_available():
                    images = images.cuda()
                    target = target.cuda(non_blocking=True)

                # compute output
                if args.augerino and args.inv_per_class:
                    output = model(images, target)
                else:
                    output = model(images)
                if args.augerino:
                    loss = criterion(output, target, model, args, reg=args.aug_reg)
                else:
                    loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))

                # NOTE(review): presumably `accuracy` returns per-batch counts
                # of correct predictions, turned into percentages here --
                # confirm against its definition.
                acc1 = acc1 / float(batch_size) * 100.0
                acc5 = acc5 / float(batch_size) * 100.0

                losses.update(loss.item(), batch_size)
                top1.update(acc1[0], batch_size)
                top5.update(acc5[0], batch_size)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i)

            # TODO: this should also be done with the ProgressMeter
            print(
                " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(
                    top1=top1, top5=top5
                )
            )
    finally:
        # Always undo the temporary flag so a failed validation pass cannot
        # leave the model disabled for subsequent training.
        if toggle_augerino:
            _set_augerino_disabled(model, False)

    # The averages are tensors once the meters have been updated with tensor
    # values; `_as_scalar` also tolerates a plain number (assumes AverageMeter
    # starts `avg` as a number when no batch was seen -- TODO confirm).
    return _as_scalar(top1.avg), _as_scalar(top5.avg), losses.avg