in scripts/train_model.py [0:0]
from torch.autograd import Variable

import iep.preprocess
import iep.utils as utils


def check_accuracy(args, program_generator, execution_engine, baseline_model, loader):
    set_mode('eval', [program_generator, execution_engine, baseline_model])
    # Load the vocab once up front; the PG branch below needs it to decode
    # predicted and ground-truth programs into token strings.
    if args.model_type == 'PG':
        vocab = utils.load_vocab(args.vocab_json)
    num_correct, num_samples = 0, 0
    for batch in loader:
        questions, _, feats, answers, programs, _ = batch
        # volatile=True marks these as inference-only (pre-0.4 PyTorch idiom),
        # so no autograd graph is built during evaluation.
        questions_var = Variable(questions.cuda(), volatile=True)
        feats_var = Variable(feats.cuda(), volatile=True)
        answers_var = Variable(answers.cuda(), volatile=True)
        if programs[0] is not None:
            programs_var = Variable(programs.cuda(), volatile=True)
        scores = None  # Set by every branch below except the PG branch
        if args.model_type == 'PG':
            # The program generator alone is scored on exact-match program
            # accuracy: a sample is correct only if the decoded predicted
            # program matches the decoded ground-truth program exactly.
            for i in range(questions.size(0)):
                program_pred = program_generator.sample(
                    Variable(questions[i:i + 1].cuda(), volatile=True))
                program_pred_str = iep.preprocess.decode(
                    program_pred, vocab['program_idx_to_token'])
                program_str = iep.preprocess.decode(
                    programs[i], vocab['program_idx_to_token'])
                if program_pred_str == program_str:
                    num_correct += 1
                num_samples += 1
        elif args.model_type == 'EE':
            # The execution engine is evaluated on ground-truth programs.
            scores = execution_engine(feats_var, programs_var)
        elif args.model_type == 'PG+EE':
            # End-to-end model: argmax-decode a program from the question,
            # then execute it against the image features.
            programs_pred = program_generator.reinforce_sample(
                questions_var, argmax=True)
            scores = execution_engine(feats_var, programs_pred)
        elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
            scores = baseline_model(questions_var, feats_var)
        if scores is not None:
            # Answer accuracy: argmax over answer scores, compared against
            # the ground-truth answers tensor on the CPU.
            _, preds = scores.data.cpu().max(1)
            num_correct += (preds == answers).sum()
            num_samples += preds.size(0)
        if num_samples >= args.num_val_samples:
            break
    set_mode('train', [program_generator, execution_engine, baseline_model])
    acc = float(num_correct) / num_samples
    return acc
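

# check_accuracy depends on a set_mode helper defined elsewhere in
# scripts/train_model.py and not shown in this excerpt. A minimal sketch
# consistent with the call sites above (the None check is an assumption:
# only the models needed for the current --model_type are constructed,
# so the unused entries in the list are None):
def set_mode(mode, models):
    for m in models:
        if m is None:
            continue
        if mode == 'train':
            m.train()
        elif mode == 'eval':
            m.eval()


# Hypothetical call site inside the training loop (variable names assumed,
# not taken from this excerpt):
#   val_acc = check_accuracy(args, program_generator, execution_engine,
#                            baseline_model, val_loader)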