in scripts/run_model.py
def run_our_model_batch(args, program_generator, execution_engine, loader, dtype):
  # Cast both models to the requested dtype and switch to eval mode.
  program_generator.type(dtype)
  program_generator.eval()
  execution_engine.type(dtype)
  execution_engine.eval()

  all_scores, all_programs = [], []
  all_probs = []
  num_correct, num_samples = 0, 0
  for batch in loader:
    questions, images, feats, answers, programs, program_lists = batch

    questions_var = Variable(questions.type(dtype).long(), volatile=True)
    feats_var = Variable(feats.type(dtype), volatile=True)

    # Sample a program for each question; argmax sampling gives the greedy program.
    programs_pred = program_generator.reinforce_sample(
        questions_var,
        temperature=args.temperature,
        argmax=(args.sample_argmax == 1))

    # Execute either the ground-truth programs or the predicted ones.
    if args.use_gt_programs == 1:
      scores = execution_engine(feats_var, program_lists)
    else:
      scores = execution_engine(feats_var, programs_pred)
    probs = F.softmax(scores, dim=1)

    # Accumulate per-batch predictions and scores on the CPU.
    _, preds = scores.data.cpu().max(1)
    all_programs.append(programs_pred.data.cpu().clone())
    all_scores.append(scores.data.cpu().clone())
    all_probs.append(probs.data.cpu().clone())

    num_correct += (preds == answers).sum()
    num_samples += preds.size(0)
    print('Ran %d samples' % num_samples)

  acc = float(num_correct) / num_samples
  print('Got %d / %d = %.2f correct' % (num_correct, num_samples, 100 * acc))

  # Concatenate results across batches and optionally dump them to HDF5.
  all_scores = torch.cat(all_scores, 0)
  all_probs = torch.cat(all_probs, 0)
  all_programs = torch.cat(all_programs, 0)
  if args.output_h5 is not None:
    print('Writing output to "%s"' % args.output_h5)
    with h5py.File(args.output_h5, 'w') as fout:
      fout.create_dataset('scores', data=all_scores.numpy())
      fout.create_dataset('probs', data=all_probs.numpy())
      fout.create_dataset('predicted_programs', data=all_programs.numpy())
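

# Usage sketch (not part of the original script): inspecting the HDF5 file that
# run_our_model_batch writes when --output_h5 is given. The dataset names match
# those created above; the file path 'results.h5' is a hypothetical example.
import h5py
import numpy as np

with h5py.File('results.h5', 'r') as f:
  scores = f['scores'][:]                 # raw answer logits, one row per sample
  probs = f['probs'][:]                   # softmax probabilities over answers
  programs = f['predicted_programs'][:]   # sampled program token ids per sample
preds = np.argmax(scores, axis=1)         # predicted answer index for each sample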