in Classification/net_util.py [0:0]
def eval_net(net, args):
    """Evaluate ``net`` on ``args.data_loader`` and record the epoch's metrics.

    Runs one pass over ``args.data_loader`` with ``net`` in eval mode and
    gradient tracking disabled, accumulating the loss and top-1/top-5
    precision in ``AverageMeter`` objects weighted by batch size.

    When ``args.testing`` is set, the epoch averages are appended to
    ``args.test_losses`` / ``args.test_accuracies`` and logged through
    ``args.test_epoch_logger``. When ``args.validating`` is set, the same is
    done with the ``valid_*`` fields. After that, ``args.best_acc`` is
    updated, and if ``args.checkpoint_epoch > 0`` a checkpoint is written to
    ``args.checkpoint_path/checkpoint.pth.tar``. It is also copied to
    ``model_best.pth.tar`` when the top-1 average beats the previous best.

    Args:
        net: the model to evaluate. If it has a ``.module`` attribute (e.g. a
            DataParallel-style wrapper), that inner module's state dict is
            the one saved in the checkpoint.
        args: run configuration/state container supporting ``in`` membership
            tests. Read: ``data_loader``, ``use_gpu``, ``loss`` ('CE' or
            'L2'), ``criterion``, ``epoch``, ``msg``, ``best_acc`` and, when
            relevant, ``num_outputs`` plus the logger/checkpoint fields.
            Mutated: ``validating``/``testing`` (defaulted to False when
            absent), the loss/accuracy history lists, and ``best_acc``.

    Raises:
        ValueError: if ``args.loss`` is neither 'CE' nor 'L2'.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # Default the mode flags BEFORE reading them; reading first (as before)
    # raised AttributeError whenever a flag was absent from ``args``.
    if 'validating' not in args:
        args.validating = False
    if 'testing' not in args:
        args.testing = False
    if args.validating:
        print('Validating at epoch {}'.format(args.epoch + 1))
    if args.testing:
        print('Testing at epoch {}'.format(args.epoch + 1))

    # Resolve the loss mode once instead of per batch. `util` is imported
    # lazily so it is only required for L2 runs.
    if args.loss == 'L2':
        from util import targets_to_one_hot
    elif args.loss != 'CE':
        # Previously this fell through and failed with UnboundLocalError on `loss`.
        raise ValueError(
            "Unsupported args.loss {!r}; expected 'CE' or 'L2'".format(args.loss))

    net.eval()
    total_time = 0
    end_time = time.time()
    with torch.no_grad():
        for inputs, targets in args.data_loader:
            if args.use_gpu:
                # NOTE(review): only targets are moved here; inputs are
                # presumably placed on the device by the model wrapper
                # (e.g. DataParallel scatter) - confirm for unwrapped models.
                targets = targets.cuda()
            outputs = net(inputs)
            if isinstance(outputs, list):
                # Multi-output models: score only the last output.
                outputs = outputs[-1]
            if args.loss == 'CE':
                loss = args.criterion(outputs, targets)
            else:  # 'L2': compare against one-hot targets, scaled by num_outputs * 0.5
                targets_one_hot = targets_to_one_hot(targets, args.num_outputs)
                loss = args.criterion(outputs, targets_one_hot) * args.num_outputs * 0.5
            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            top1.update(prec1[0].item(), batch_size)
            top5.update(prec5[0].item(), batch_size)
            total_time += time.time() - end_time
            end_time = time.time()
            if args.msg:
                # Running-average loss, but this batch's precision.
                print('Loss: %.3f | top1: %.3f%% ,top5: %.3f%%'
                      % (losses.avg, prec1[0].item(), prec5[0].item()))

    if args.testing:
        args.test_losses.append(losses.avg)
        args.test_accuracies.append(top1.avg)
        args.test_epoch_logger.log({
            'epoch': (args.epoch + 1),
            'loss': losses.avg,
            'top1': top1.avg,
            'top5': top5.avg,
            'time': total_time,
        })

    if args.validating:
        args.valid_losses.append(losses.avg)
        args.valid_accuracies.append(top1.avg)
        args.valid_epoch_logger.log({
            'epoch': (args.epoch + 1),
            'loss': losses.avg,
            'top1': top1.avg,
            'top5': top5.avg,
            'time': total_time,
        })
        # Save checkpoint.
        is_best = top1.avg > args.best_acc
        if is_best:
            args.best_acc = top1.avg
        # Only build the (potentially large) state snapshot when it is saved.
        if args.checkpoint_epoch > 0:
            states = {
                'state_dict': (net.module.state_dict() if hasattr(net, 'module')
                               else net.state_dict()),
                'epoch': args.epoch + 1,
                'arch': args.arch,
                'best_acc': args.best_acc,
                'train_losses': args.train_losses,
                'optimizer': args.current_optimizer.state_dict(),
            }
            if 'acc' in args:
                # Store the float itself; a stray trailing comma used to
                # store a 1-tuple here.
                states['acc'] = top1.avg
            if 'valid_losses' in args:
                states['valid_losses'] = args.valid_losses
            if 'test_losses' in args:
                states['test_losses'] = args.test_losses
            # makedirs handles nested paths and an already-existing directory.
            os.makedirs(args.checkpoint_path, exist_ok=True)
            save_file_path = os.path.join(args.checkpoint_path, 'checkpoint.pth.tar')
            torch.save(states, save_file_path)
            if is_best:
                shutil.copyfile(save_file_path,
                                os.path.join(args.checkpoint_path, 'model_best.pth.tar'))

    # Note: '%3.f' prints elapsed time as whole seconds (precision 0).
    print('Loss: %.3f | top1: %.3f%%, top5: %.3f%%, elasped time: %3.f seconds. Best Acc: %.3f%%'
          % (losses.avg, top1.avg, top5.avg, total_time, args.best_acc))