def validate()

in distributed_training/src_dir/main_trainer.py

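Runs one full evaluation pass over the validation set, handling both the
plain data-parallel path and the SageMaker model-parallel path selected by
args.model_parallel. The excerpt assumes module-level imports along these
lines; util and dis_util are project-local helpers, so the exact import
form is an assumption:

import time

import torch

import util      # AverageMeter, ProgressMeter, accuracy, to_python_float
import dis_util  # test_step, smp_lossgather, smp_barrier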

def validate(val_loader, model, criterion, epoch, model_history, args):
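    """Run one evaluation pass over val_loader; rank 0 logs per-batch metrics
    into model_history. Returns the running average top-1 accuracy."""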
    batch_time = util.AverageMeter('Time', ':6.3f')
    losses = util.AverageMeter('Loss', ':.4e')
    top1 = util.AverageMeter('Acc@1', ':6.2f')
    top5 = util.AverageMeter('Acc@5', ':6.2f')
    progress = util.ProgressMeter(len(val_loader),
                                  [batch_time, losses, top1, top5],
                                  prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    end = time.time()

    #     print("**** validate *****")
    test_losses = []
    for batch_idx, (input, target) in enumerate((val_loader)):
        input = input.to(args.device)
        target = target.to(args.device)

        batch_idx += 1
        # compute output
        with torch.no_grad():
            if args.model_parallel:
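                # dis_util.test_step presumably wraps an smp.step-decorated
                # forward pass; smp returns per-microbatch results, so the
                # loss is averaged across microbatches via reduce_mean()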
                output, loss = dis_util.test_step(model, criterion, input,
                                                  target)
                loss = loss.reduce_mean()
                test_losses.append(loss)
            else:
                output = model(input)
                loss = criterion(output, target)

        # measure accuracy and record loss
        if args.model_parallel:
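            # stitch the per-microbatch output tensors back into one batch
            # before computing accuracy (output is presumably an smp StepOutput)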
            output = torch.cat(output.outputs)

        prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

        losses.update(util.to_python_float(loss), input.size(0))
        top1.update(util.to_python_float(prec1), input.size(0))
        top5.update(util.to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #         print("Validation args.rank : {}".format(args.rank))
        # TODO:  Change timings to mirror train().
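        # only the lead rank prints and records per-batch history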
        if args.rank == 0:
            print('Test: [{0}/{1}]  '
                  'Test_Time={batch_time.val:.3f}:({batch_time.avg:.3f}), '
                  'Test_Speed={2:.3f}:({3:.3f}), '
                  'Test_Loss={loss.val:.4f}:({loss.avg:.4f}), '
                  'Test_Prec@1={top1.val:.3f}:({top1.avg:.3f}), '
                  'Test_Prec@5={top5.val:.3f}:({top5.avg:.3f})'.format(
                      batch_idx,
                      len(val_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
            model_history['val_epoch'].append(epoch)
            model_history['val_batch_idx'].append(batch_idx)
            model_history['val_batch_time'].append(batch_time.val)
            model_history['val_losses'].append(losses.val)
            model_history['val_top1'].append(top1.val)
            model_history['val_top5'].append(top5.val)

    model_history['val_avg_epoch'].append(epoch)
    model_history['val_avg_batch_time'].append(batch_time.avg)
    model_history['val_avg_losses'].append(losses.avg)
    model_history['val_avg_top1'].append(top1.avg)
    model_history['val_avg_top5'].append(top5.avg)

    if args.model_parallel:
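        # presumably gathers per-rank losses for a consistency check,
        # then synchronizes all ranks before returning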
        if args.assert_losses:
            dis_util.smp_lossgather(losses.avg, args)
        dis_util.smp_barrier()
        
    return top1.avg
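
For context, a minimal sketch of how validate() might be wired into an epoch
loop. The model_history keys are taken from what the function appends; the
train() call, args.epochs, and the loader/model/criterion setup are
assumptions, not code from main_trainer.py:

# hypothetical call site -- everything except validate() is an assumption
model_history = {key: [] for key in [
    'val_epoch', 'val_batch_idx', 'val_batch_time', 'val_losses',
    'val_top1', 'val_top5', 'val_avg_epoch', 'val_avg_batch_time',
    'val_avg_losses', 'val_avg_top1', 'val_avg_top5']}

best_acc1 = 0.0
for epoch in range(args.epochs):
    train(train_loader, model, criterion, optimizer, epoch, model_history, args)
    acc1 = validate(val_loader, model, criterion, epoch, model_history, args)
    best_acc1 = max(acc1, best_acc1)  # validate() returns the top-1 average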