def validate()

in PyTorchClassification/train.py [0:0]


def validate(val_loader, model, criterion, epoch, global_step, save_preds=False):
    """Evaluate ``model`` on ``val_loader`` and log loss and top-1/3/5 precision.

    Every batch from the loader carries several crops of the same images. The
    model is run on each crop and the outputs are averaged over the crops
    before the loss and the precision values are computed.

    Relies on names defined elsewhere in this module: ``AverageMeter``,
    ``accuracy``, ``print_log``, ``args`` (for ``print_freq``) and ``writer``.
    Inputs and targets are moved to the GPU with ``.cuda()``.

    Args:
        val_loader: Sized iterable yielding ``(crops, image_ids, target)``.
            ``crops`` is indexable, one input batch per crop.
        model: Network to evaluate; it is switched to eval mode and left there.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        epoch: Epoch number, used only in the summary log line.
        global_step: Step value passed to ``writer`` for the logged scalars.
        save_preds: When true, also collect the top-3 predicted class indices
            and the image ids of every batch.

    Returns:
        ``(top1_avg, top3_avg, top5_avg)``; when ``save_preds`` is true the
        tuple is extended with the vertically stacked top-3 predictions and
        the horizontally stacked image ids.

    Raises:
        ValueError: If a batch contains no crops.
    """
    with torch.no_grad():
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top3 = AverageMeter()
        top5 = AverageMeter()

        # switch to evaluate mode
        model.eval()

        end = time.time()
        pred = []
        im_ids = []

        print_log('Validate:\tTime\t\tLoss\t\tPrec@1\t\tPrec@3\t\tPrec@5')
        for i, (inputIn, im_id, target) in tqdm.tqdm(enumerate(val_loader), total=len(val_loader)):
            num_crops = len(inputIn)
            if num_crops == 0:
                # Without this guard the code below would reuse the previous
                # batch's output (or fail with an unbound local).
                raise ValueError('validation batch {} contains no crops'.format(i))

            # Sum the model outputs over all crops ...
            output = None
            crop_input = None
            for j in range(num_crops):
                crop_input = inputIn[j].cuda()
                crop_output = model(crop_input)
                output = crop_output if output is None else output + crop_output
            # ... and divide once, after the loop, to obtain the mean over
            # crops. Dividing inside the loop would shrink earlier crops
            # repeatedly and weight the crops unequally.
            output = output / num_crops

            target = target.cuda()
            loss = criterion(output, target)

            if save_preds:
                # store the top K classes for the prediction
                # (builtin int replaces the removed np.int alias)
                im_ids.append(im_id.cpu().numpy().astype(int))
                _, pred_inds = output.data.topk(3, 1, True, True)
                pred.append(pred_inds.cpu().numpy().astype(int))

            # measure accuracy and record loss
            batch_size = crop_input.size(0)
            prec1, prec3, prec5 = accuracy(output.data, target, topk=(1, 3, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top3.update(prec3.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure throughput (samples per second) for this batch
            batch_time.update(batch_size / (time.time() - end))
            end = time.time()

            if i % args.print_freq == 0:
                print_log('[{0}/{1}]\t'
                          '{batch_time.val:.2f} ({batch_time.avg:.2f})\t'
                          '{loss.val:.3f} ({loss.avg:.3f})\t'
                          '{top1.val:.2f} ({top1.avg:.2f})\t'
                          '{top3.val:.2f} ({top3.avg:.2f})\t'
                          '{top5.val:.2f} ({top5.avg:.2f})'.format(
                              i, len(val_loader), batch_time=batch_time, loss=losses,
                              top1=top1, top3=top3, top5=top5))

        writer.add_scalar('validation/loss', losses.avg, global_step)
        writer.add_scalars('validation/topk', {'top1': top1.avg,
                                               'top3': top3.avg,
                                               'top5': top5.avg}, global_step)

        print_log(' *** Validation summary at epoch {epoch:d}: Prec@1 {top1.avg:.3f} '.format(epoch=epoch, top1=top1) +
                  'Prec@3 {top3.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.3f}'.format(top3=top3, top5=top5, loss=losses))

        if save_preds:
            return top1.avg, top3.avg, top5.avg, np.vstack(pred), np.hstack(im_ids)
        return top1.avg, top3.avg, top5.avg