in PyTorchClassification/train.py [0:0]
def _average_crop_outputs(model, crops):
    """Run ``model`` on every crop of one batch and return the mean output.

    ``crops`` is iterated once, one entry per crop; presumably every entry is
    a batch tensor of the same shape (TODO confirm against the val_loader's
    transform). Each crop is moved to the GPU with ``.cuda()`` first.

    Raises:
        ValueError: if ``crops`` is empty, since no output can be formed.
    """
    output = None
    num_crops = 0
    for crop in crops:
        logits = model(crop.cuda())
        # first crop seeds the sum; later crops are accumulated
        output = logits if output is None else output + logits
        num_crops += 1
    if output is None:
        raise ValueError('validate: received a batch with no crops')
    return output / num_crops


def validate(val_loader, model, criterion, epoch, global_step, save_preds=False):
    """Evaluate ``model`` on ``val_loader`` and log loss and top-k precision.

    Each batch from ``val_loader`` is unpacked as ``(crops, im_id, target)``.
    The model is run on every crop, and the outputs are averaged (multi-crop
    evaluation) before the loss and accuracy are computed. Everything runs
    under ``torch.no_grad()``.

    Args:
        val_loader: iterable of ``(crops, im_id, target)`` batches that also
            supports ``len()``.
        model: network to evaluate. It is put in eval mode here and is not
            switched back.
        criterion: loss function, called as ``criterion(output, target)``.
        epoch: epoch number, used only in the summary log line.
        global_step: step index used for the ``writer`` scalar logs.
        save_preds: if True, also collect the top-3 predicted class indices
            and the image ids of every batch.

    Returns:
        ``(top1, top3, top5)`` average precisions. When ``save_preds`` is
        True, returns ``(top1, top3, top5, preds, im_ids)`` instead, where
        ``preds`` row-stacks the per-batch top-3 index arrays and ``im_ids``
        concatenates the per-batch id arrays.

    Raises:
        ValueError: if a batch contains no crops.
    """
    with torch.no_grad():
        # NOTE: despite its name, batch_time records throughput
        # (samples per second), not elapsed time.
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top3 = AverageMeter()
        top5 = AverageMeter()
        # switch to evaluate mode
        model.eval()
        end = time.time()
        pred = []
        im_ids = []
        print_log('Validate:\tTime\t\tLoss\t\tPrec@1\t\tPrec@3\t\tPrec@5')
        for i, (inputIn, im_id, target) in tqdm.tqdm(enumerate(val_loader), total=len(val_loader)):
            output = _average_crop_outputs(model, inputIn)
            target = target.cuda()
            loss = criterion(output, target)
            # number of labelled samples in this batch, used as the meter weight
            batch_size = target.size(0)
            if save_preds:
                # store the top-3 class indices per sample; builtin int is what
                # the removed np.int alias referred to (NumPy >= 1.24)
                im_ids.append(im_id.cpu().numpy().astype(int))
                _, pred_inds = output.topk(3, 1, True, True)
                pred.append(pred_inds.cpu().numpy().astype(int))
            # measure accuracy and record loss
            prec1, prec3, prec5 = accuracy(output, target, topk=(1, 3, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top3.update(prec3.item(), batch_size)
            top5.update(prec5.item(), batch_size)
            # samples processed per second for this batch
            batch_time.update(batch_size / (time.time() - end))
            end = time.time()
            if i % args.print_freq == 0:
                print_log('[{0}/{1}]\t'
                          '{batch_time.val:.2f} ({batch_time.avg:.2f})\t'
                          '{loss.val:.3f} ({loss.avg:.3f})\t'
                          '{top1.val:.2f} ({top1.avg:.2f})\t'
                          '{top3.val:.2f} ({top3.avg:.2f})\t'
                          '{top5.val:.2f} ({top5.avg:.2f})'.format(i, len(val_loader), batch_time=batch_time, loss=losses,
                                                                     top1=top1, top3=top3, top5=top5))
        writer.add_scalar('validation/loss', losses.avg, global_step)
        writer.add_scalars('validation/topk', {'top1': top1.avg,
                                               'top3': top3.avg,
                                               'top5': top5.avg}, global_step)
        print_log(' *** Validation summary at epoch {epoch:d}: Prec@1 {top1.avg:.3f} '.format(epoch=epoch, top1=top1) +
                  'Prec@3 {top3.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.3f}'.format(top3=top3, top5=top5, loss=losses))
        if save_preds:
            return top1.avg, top3.avg, top5.avg, np.vstack(pred), np.hstack(im_ids)
        return top1.avg, top3.avg, top5.avg