in scripts/train_model.py
def train_loop(args, train_loader, val_loader):
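  """Main training loop.

  Depending on args.model_type, this trains the program generator (PG) and/or
  execution engine (EE) on ground-truth programs, one of the LSTM / CNN+LSTM /
  CNN+LSTM+SA baselines, or the joint PG+EE model with REINFORCE. It periodically
  checks train/val accuracy and checkpoints the best model seen on validation.
  """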
  vocab = utils.load_vocab(args.vocab_json)
  program_generator, pg_kwargs, pg_optimizer = None, None, None
  execution_engine, ee_kwargs, ee_optimizer = None, None, None
  baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
  baseline_type = None
  best_pg_state, best_ee_state, best_baseline_state = None, None, None
  # Set up model
  if args.model_type == 'PG' or args.model_type == 'PG+EE':
    program_generator, pg_kwargs = get_program_generator(args)
    pg_optimizer = torch.optim.Adam(program_generator.parameters(),
                                    lr=args.learning_rate)
    print('Here is the program generator:')
    print(program_generator)
  if args.model_type == 'EE' or args.model_type == 'PG+EE':
    execution_engine, ee_kwargs = get_execution_engine(args)
    ee_optimizer = torch.optim.Adam(execution_engine.parameters(),
                                    lr=args.learning_rate)
    print('Here is the execution engine:')
    print(execution_engine)
  if args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
    baseline_model, baseline_kwargs = get_baseline_model(args)
    params = baseline_model.parameters()
    if args.baseline_train_only_rnn == 1:
      params = baseline_model.rnn.parameters()
    baseline_optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    print('Here is the baseline model')
    print(baseline_model)
    baseline_type = args.model_type
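  # Cross-entropy over answer scores is used for the EE, baseline, and PG+EE
  # branches; the PG-only branch uses the loss returned by the generator itself.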
  loss_fn = torch.nn.CrossEntropyLoss().cuda()

  stats = {
    'train_losses': [], 'train_rewards': [], 'train_losses_ts': [],
    'train_accs': [], 'val_accs': [], 'val_accs_ts': [],
    'best_val_acc': -1, 'model_t': 0,
  }
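  # t counts minibatch iterations, epoch counts passes over train_loader, and
  # reward_moving_average is the REINFORCE baseline used for PG+EE training.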
  t, epoch, reward_moving_average = 0, 0, 0

  set_mode('train', [program_generator, execution_engine, baseline_model])

  print('train_loader has %d samples' % len(train_loader.dataset))
  print('val_loader has %d samples' % len(val_loader.dataset))
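  # The outer loop counts epochs, but training actually runs for a fixed number
  # of minibatch iterations (args.num_iterations), checked at the end of each batch.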
  while t < args.num_iterations:
    epoch += 1
    print('Starting epoch %d' % epoch)
    for batch in train_loader:
      t += 1
      questions, _, feats, answers, programs, _ = batch
      questions_var = Variable(questions.cuda())
      feats_var = Variable(feats.cuda())
      answers_var = Variable(answers.cuda())
      if programs[0] is not None:
        programs_var = Variable(programs.cuda())
      reward = None
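      # Dispatch on model_type: the PG and EE branches need ground-truth
      # programs; the baselines and PG+EE use only questions, features, answers.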
      if args.model_type == 'PG':
        # Train program generator with ground-truth programs
        pg_optimizer.zero_grad()
        loss = program_generator(questions_var, programs_var)
        loss.backward()
        pg_optimizer.step()
      elif args.model_type == 'EE':
        # Train execution engine with ground-truth programs
        ee_optimizer.zero_grad()
        scores = execution_engine(feats_var, programs_var)
        loss = loss_fn(scores, answers_var)
        loss.backward()
        ee_optimizer.step()
elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
        baseline_optimizer.zero_grad()
        baseline_model.zero_grad()
        scores = baseline_model(questions_var, feats_var)
        loss = loss_fn(scores, answers_var)
        loss.backward()
        baseline_optimizer.step()
      elif args.model_type == 'PG+EE':
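        # Joint training: sample a program from the PG, execute it with the EE,
        # and use answer correctness as a REINFORCE reward for the PG. A moving
        # average of the reward serves as a baseline to reduce gradient variance.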
        programs_pred = program_generator.reinforce_sample(questions_var)
        scores = execution_engine(feats_var, programs_pred)

        loss = loss_fn(scores, answers_var)
        _, preds = scores.data.cpu().max(1)
        raw_reward = (preds == answers).float()
        reward = raw_reward.mean()  # mean batch reward, logged below
        reward_moving_average *= args.reward_decay
        reward_moving_average += (1.0 - args.reward_decay) * reward
        centered_reward = raw_reward - reward_moving_average

        if args.train_execution_engine == 1:
          ee_optimizer.zero_grad()
          loss.backward()
          ee_optimizer.step()
        if args.train_program_generator == 1:
          pg_optimizer.zero_grad()
          program_generator.reinforce_backward(centered_reward.cuda())
          pg_optimizer.step()
      if t % args.record_loss_every == 0:
        print(t, loss.data[0])
        stats['train_losses'].append(loss.data[0])
        stats['train_losses_ts'].append(t)
        if reward is not None:
          stats['train_rewards'].append(reward)
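      # Periodically evaluate train/val accuracy, track the best validation
      # model seen so far, and write a checkpoint.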
      if t % args.checkpoint_every == 0:
        print('Checking training accuracy ... ')
        train_acc = check_accuracy(args, program_generator, execution_engine,
                                   baseline_model, train_loader)
        print('train accuracy is', train_acc)
        print('Checking validation accuracy ...')
        val_acc = check_accuracy(args, program_generator, execution_engine,
                                 baseline_model, val_loader)
        print('val accuracy is ', val_acc)
        stats['train_accs'].append(train_acc)
        stats['val_accs'].append(val_acc)
        stats['val_accs_ts'].append(t)

        if val_acc > stats['best_val_acc']:
          stats['best_val_acc'] = val_acc
          stats['model_t'] = t
          best_pg_state = get_state(program_generator)
          best_ee_state = get_state(execution_engine)
          best_baseline_state = get_state(baseline_model)
        checkpoint = {
          'args': args.__dict__,
          'program_generator_kwargs': pg_kwargs,
          'program_generator_state': best_pg_state,
          'execution_engine_kwargs': ee_kwargs,
          'execution_engine_state': best_ee_state,
          'baseline_kwargs': baseline_kwargs,
          'baseline_state': best_baseline_state,
          'baseline_type': baseline_type,
          'vocab': vocab
        }
        for k, v in stats.items():
          checkpoint[k] = v
        print('Saving checkpoint to %s' % args.checkpoint_path)
        torch.save(checkpoint, args.checkpoint_path)
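        # Also dump a weights-free JSON copy of the checkpoint (args, kwargs,
        # stats) so it can be inspected without loading the full torch checkpoint.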
        del checkpoint['program_generator_state']
        del checkpoint['execution_engine_state']
        del checkpoint['baseline_state']
        with open(args.checkpoint_path + '.json', 'w') as f:
          json.dump(checkpoint, f)

      if t == args.num_iterations:
        break
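
# Note: a checkpoint written above can later be reloaded with
# torch.load(args.checkpoint_path); model weights live under the '*_state'
# keys and constructor kwargs under the '*_kwargs' keys, while the '.json'
# copy holds the same args/kwargs/stats without the weights.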