in training/train_vqa.py [0:0]
def eval(rank, args, shared_model):
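    """Evaluation worker.

    Each epoch, copies the latest weights from `shared_model`, scores them on
    the `args.eval_split` questions, and checkpoints whenever the best
    accuracy so far improves (subject to `eval_every` and `log`).
    """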
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()
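
    # data loader for the held-out split; questions are served one at a time
    # (batch_size 1) and environments are capped per GPU by max_threads_per_gpu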
    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache
    }
    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0
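
    # re-evaluate a fresh snapshot of the shared training weights once per epoch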
    while epoch < int(args.max_epochs):

        model.load_state_dict(shared_model.state_dict())
        model.eval()

        metrics = VqaMetric(
            info={'split': args.eval_split},
            metric_names=[
                'loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'
            ],
            log_json=args.output_log_path)
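
        # two input modes: question-only LSTM vs. question + image frames with attention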
        if args.input_type == 'ques':

            for batch in eval_loader:

                t += 1

                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

            print(metrics.get_stat_string(mode=0))

        elif args.input_type == 'ques,image':

            done = False
            all_envs_loaded = eval_loader.dataset._check_if_all_envs_loaded()
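
            # environments are loaded onto the GPU in chunks (bounded by
            # max_threads_per_gpu); keep evaluating until the loader has
            # cycled through every environment in the split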
            while not done:

                for batch in eval_loader:

                    t += 1

                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update(
                        [loss.data[0], accuracy, ranks, 1.0 / ranks])

                print(metrics.get_stat_string(mode=0))

                if not all_envs_loaded:
                    eval_loader.dataset._load_envs()
                    if len(eval_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy
        if metrics.metrics[1][0] > best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]
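
            # dump logs and save a checkpoint only every `eval_every` epochs
            # (and only when logging is enabled)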
            if epoch % args.eval_every == 0 and args.log == True:
                metrics.dump_log()

                model_state = get_state(model)

                if args.checkpoint_path != False:
                    # assumes `checkpoint` was loaded earlier from
                    # args.checkpoint_path (e.g. via torch.load); otherwise
                    # this raises a NameError
                    ad = checkpoint['args']
                else:
                    ad = args.__dict__

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)

                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)
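

# A minimal sketch of how this eval worker is typically launched alongside
# training (illustrative only; assumes `args` and the vocab come from the
# surrounding script):
#
#   import torch.multiprocessing as mp
#
#   shared_model = VqaLstmCnnAttentionModel(vocab=load_vocab(args.vocab_json))
#   shared_model.share_memory()
#
#   eval_proc = mp.Process(target=eval, args=(0, args, shared_model))
#   eval_proc.start()
#   eval_proc.join()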