in distributed_training/src_dir/main_trainer.py [0:0]
def _samples_per_second(samples, seconds):
    """Return ``samples / seconds``, or NaN when ``seconds`` is not positive.

    Guards the throughput readout against a zero-length timing interval,
    which would otherwise raise ZeroDivisionError mid-validation.
    """
    if seconds > 0:
        return samples / seconds
    return float('nan')


def validate(val_loader, model, criterion, epoch, model_history, args):
    """Run one evaluation pass over ``val_loader`` and return top-1 accuracy.

    The model is switched to eval mode (and left there), and each batch is
    evaluated under ``torch.no_grad()``. When ``args.model_parallel`` is set,
    the forward pass goes through ``dis_util.test_step`` and the returned
    ``output.outputs`` tensors are concatenated before computing accuracy.

    Per-batch metrics are printed and appended to ``model_history`` on
    rank 0 only; the epoch-level averages are appended on every rank.

    Args:
        val_loader: Iterable of ``(inputs, target)`` batches. ``len()`` is
            taken for the ``[i/N]`` progress display.
        model: Model to evaluate. ``model.eval()`` is called and not undone.
        criterion: Loss function, called as ``criterion(output, target)`` on
            the non-model-parallel path and passed to ``dis_util.test_step``
            otherwise.
        epoch: Epoch number recorded in ``model_history``.
        model_history: Mapping whose ``val_*`` / ``val_avg_*`` entries used
            below must support ``.append(...)``.
        args: Namespace read for ``device``, ``model_parallel``, ``rank``,
            ``world_size``, ``batch_size`` and (model-parallel only)
            ``assert_losses``.

    Returns:
        The averaged top-1 accuracy as reported by ``top1.avg``.
    """
    batch_time = util.AverageMeter('Time', ':6.3f')
    losses = util.AverageMeter('Loss', ':.4e')
    top1 = util.AverageMeter('Acc@1', ':6.2f')
    top5 = util.AverageMeter('Acc@5', ':6.2f')

    # switch to evaluate mode
    model.eval()
    # perf_counter is monotonic; time.time() is wall-clock and can jump.
    end = time.perf_counter()

    # Count batches from 1 so the log reads "[1/N]" ... "[N/N]".
    for batch_idx, (inputs, target) in enumerate(val_loader, start=1):
        inputs = inputs.to(args.device)
        target = target.to(args.device)

        # compute output
        with torch.no_grad():
            if args.model_parallel:
                output, loss = dis_util.test_step(model, criterion, inputs,
                                                  target)
                loss = loss.reduce_mean()
            else:
                output = model(inputs)
                loss = criterion(output, target)

        # measure accuracy and record loss
        if args.model_parallel:
            # test_step yields a collection of output tensors (presumably one
            # per microbatch); join them so accuracy sees the whole batch.
            output = torch.cat(output.outputs)

        prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

        n_samples = inputs.size(0)
        losses.update(util.to_python_float(loss), n_samples)
        top1.update(util.to_python_float(prec1), n_samples)
        top5.update(util.to_python_float(prec5), n_samples)

        # measure elapsed time
        batch_time.update(time.perf_counter() - end)
        end = time.perf_counter()

        # TODO: Change timings to mirror train().
        if args.rank == 0:
            # Throughput is reported for the nominal global batch
            # (world_size * batch_size), not the actual last-batch size.
            global_batch = args.world_size * args.batch_size
            print('Test: [{0}/{1}] '
                  'Test_Time={batch_time.val:.3f}:({batch_time.avg:.3f}), '
                  'Test_Speed={2:.3f}:({3:.3f}), '
                  'Test_Loss={loss.val:.4f}:({loss.avg:.4f}), '
                  'Test_Prec@1={top1.val:.3f}:({top1.avg:.3f}), '
                  'Test_Prec@5={top5.val:.3f}:({top5.avg:.3f})'.format(
                      batch_idx,
                      len(val_loader),
                      _samples_per_second(global_batch, batch_time.val),
                      _samples_per_second(global_batch, batch_time.avg),
                      batch_time=batch_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
            model_history['val_epoch'].append(epoch)
            model_history['val_batch_idx'].append(batch_idx)
            model_history['val_batch_time'].append(batch_time.val)
            model_history['val_losses'].append(losses.val)
            model_history['val_top1'].append(top1.val)
            model_history['val_top5'].append(top5.val)

    # Epoch-level averages: appended once per call, on every rank.
    model_history['val_avg_epoch'].append(epoch)
    model_history['val_avg_batch_time'].append(batch_time.avg)
    model_history['val_avg_losses'].append(losses.avg)
    model_history['val_avg_top1'].append(top1.avg)
    model_history['val_avg_top5'].append(top5.avg)

    if args.model_parallel:
        if args.assert_losses:
            dis_util.smp_lossgather(losses.avg, args)
        # NOTE(review): presumably a cross-rank sync point; see dis_util.
        dis_util.smp_barrier()
    return top1.avg