in training/train_nav.py [0:0]
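
# NOTE: the imports below are assumed from the full module (this excerpt starts
# at the function definition); exact import paths may differ in the source file.
import os
import time
import logging

import torch
import torch.nn.functional as F
from torch.autograd import Variable

# Project-local helpers, names inferred from usage inside train():
# from models import (NavCnnModel, NavCnnRnnModel, NavCnnRnnMultModel,
#                     NavPlannerControllerModel, MaskedNLLCriterion)
# from data import EqaDataLoader, load_vocab
# from metrics import NavMetric
# ensure_shared_grads() and get_state() are assumed to come from a shared
# training utility in this repository.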
def train(rank, args, shared_model):
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))
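
    # Build a thread-local model matching the requested model_type; the '+q'
    # variants additionally condition on the encoded question.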
    if args.model_type == 'cnn':
        model_kwargs = {}
        model = NavCnnModel(**model_kwargs)
    elif args.model_type == 'cnn+q':
        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnModel(**model_kwargs)
    elif args.model_type == 'lstm':
        model_kwargs = {}
        model = NavCnnRnnModel(**model_kwargs)
    elif args.model_type == 'lstm-mult+q':
        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnMultModel(**model_kwargs)
    elif args.model_type == 'lstm+q':
        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnModel(**model_kwargs)
    elif args.model_type == 'pacman':
        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        model = NavPlannerControllerModel(**model_kwargs)
    else:
        exit()
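
    # Loss and optimizer. The optimizer holds the parameters of the *shared*
    # model; each worker computes gradients on its local copy and transfers
    # them via ensure_shared_grads() before stepping.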
    lossFn = torch.nn.CrossEntropyLoss().cuda()

    optim = torch.optim.Adamax(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)
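
    # Data-loader configuration for this worker: environments are loaded a few
    # at a time (bounded by max_threads_per_gpu) and image features may be
    # cached in memory when caching is enabled.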
    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.model_type,
        'num_frames': 5,
        'map_resolution': args.map_resolution,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache,
        'overfit': args.overfit,
        'max_controller_actions': args.max_controller_actions,
        'max_actions': args.max_actions
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')
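
    # Per-thread metric logger; the planner-controller ('pacman') model tracks
    # separate planner and controller losses.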
    if 'pacman' in args.model_type:
        metrics = NavMetric(
            info={'split': 'train',
                  'thread': rank},
            metric_names=['planner_loss', 'controller_loss'],
            log_json=args.output_log_path)
    else:
        metrics = NavMetric(
            info={'split': 'train',
                  'thread': rank},
            metric_names=['loss'],
            log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    print('train_loader has %d samples' % len(train_loader.dataset))
    logging.info('TRAIN: train loader has {} samples'.format(
        len(train_loader.dataset)))

    t, epoch = 0, 0
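
    # Main training loop. Each model family has its own inner loop; all of
    # them iterate over the currently loaded environments, then pull in the
    # next set via _load_envs() until every environment has been seen ('done').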
    while epoch < int(args.max_epochs):

        if 'cnn' in args.model_type:

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, img_feats, _, actions_out, _ = batch

                    img_feats_var = Variable(img_feats.cuda())
                    if '+q' in args.model_type:
                        questions_var = Variable(questions.cuda())
                    actions_out_var = Variable(actions_out.cuda())

                    if '+q' in args.model_type:
                        scores = model(img_feats_var, questions_var)
                    else:
                        scores = model(img_feats_var)

                    loss = lossFn(scores, actions_out_var)

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update([loss.data[0]])

                    # backprop and update
                    loss.backward()
                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(
                            metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                print('[CHECK][Cache:%d][Total:%d]' %
                      (len(train_loader.dataset.img_data_cache),
                       len(train_loader.dataset.env_list)))
                logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                    len(train_loader.dataset.img_data_cache),
                    len(train_loader.dataset.env_list)))

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True
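
        # LSTM variants: action sequences are sorted by decreasing length for
        # packed-sequence processing, and a masked NLL loss ignores padded
        # timesteps (plus any steps blocked out by the curriculum).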
        elif 'lstm' in args.model_type:

            lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()
            total_times = []

            while done == False:

                start_time = time.time()
                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, img_feats, actions_in, actions_out, action_lengths, masks = batch

                    img_feats_var = Variable(img_feats.cuda())
                    if '+q' in args.model_type:
                        questions_var = Variable(questions.cuda())
                    actions_in_var = Variable(actions_in.cuda())
                    actions_out_var = Variable(actions_out.cuda())
                    action_lengths = action_lengths.cuda()
                    masks_var = Variable(masks.cuda())

                    action_lengths, perm_idx = action_lengths.sort(
                        0, descending=True)
                    img_feats_var = img_feats_var[perm_idx]
                    if '+q' in args.model_type:
                        questions_var = questions_var[perm_idx]
                    actions_in_var = actions_in_var[perm_idx]
                    actions_out_var = actions_out_var[perm_idx]
                    masks_var = masks_var[perm_idx]

                    if '+q' in args.model_type:
                        scores, hidden = model(img_feats_var, questions_var,
                                               actions_in_var,
                                               action_lengths.cpu().numpy())
                    else:
                        scores, hidden = model(img_feats_var, False,
                                               actions_in_var,
                                               action_lengths.cpu().numpy())

                    # block out masks
                    if args.curriculum:
                        curriculum_length = (epoch + 1) * 5
                        for i, action_length in enumerate(action_lengths):
                            if action_length - curriculum_length > 0:
                                masks_var[i, :action_length - curriculum_length] = 0

                    logprob = F.log_softmax(scores, dim=1)
                    loss = lossFn(
                        logprob,
                        actions_out_var[:, :action_lengths.max()]
                        .contiguous().view(-1, 1),
                        masks_var[:, :action_lengths.max()]
                        .contiguous().view(-1, 1))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update([loss.data[0]])
                    logging.info("TRAIN LSTM loss: {:.6f}".format(loss.data[0]))

                    # backprop and update
                    loss.backward()
                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(
                            metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                print('[CHECK][Cache:%d][Total:%d]' %
                      (len(train_loader.dataset.img_data_cache),
                       len(train_loader.dataset.env_list)))
                logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                    len(train_loader.dataset.img_data_cache),
                    len(train_loader.dataset.env_list)))

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True
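
        # PACMAN (planner-controller) variant: planner and controller each get
        # their own masked NLL loss; when max_controller_actions == 1 only the
        # planner loss is backpropagated.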
        elif 'pacman' in args.model_type:

            planner_lossFn = MaskedNLLCriterion().cuda()
            controller_lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, planner_img_feats, planner_actions_in, \
                        planner_actions_out, planner_action_lengths, planner_masks, \
                        controller_img_feats, controller_actions_in, planner_hidden_idx, \
                        controller_outs, controller_action_lengths, controller_masks = batch

                    questions_var = Variable(questions.cuda())
                    planner_img_feats_var = Variable(planner_img_feats.cuda())
                    planner_actions_in_var = Variable(planner_actions_in.cuda())
                    planner_actions_out_var = Variable(planner_actions_out.cuda())
                    planner_action_lengths = planner_action_lengths.cuda()
                    planner_masks_var = Variable(planner_masks.cuda())

                    controller_img_feats_var = Variable(controller_img_feats.cuda())
                    controller_actions_in_var = Variable(controller_actions_in.cuda())
                    planner_hidden_idx_var = Variable(planner_hidden_idx.cuda())
                    controller_outs_var = Variable(controller_outs.cuda())
                    controller_action_lengths = controller_action_lengths.cuda()
                    controller_masks_var = Variable(controller_masks.cuda())

                    planner_action_lengths, perm_idx = planner_action_lengths.sort(
                        0, descending=True)

                    questions_var = questions_var[perm_idx]
                    planner_img_feats_var = planner_img_feats_var[perm_idx]
                    planner_actions_in_var = planner_actions_in_var[perm_idx]
                    planner_actions_out_var = planner_actions_out_var[perm_idx]
                    planner_masks_var = planner_masks_var[perm_idx]

                    controller_img_feats_var = controller_img_feats_var[perm_idx]
                    controller_actions_in_var = controller_actions_in_var[perm_idx]
                    controller_outs_var = controller_outs_var[perm_idx]
                    planner_hidden_idx_var = planner_hidden_idx_var[perm_idx]
                    controller_action_lengths = controller_action_lengths[perm_idx]
                    controller_masks_var = controller_masks_var[perm_idx]

                    planner_scores, controller_scores, planner_hidden = model(
                        questions_var, planner_img_feats_var,
                        planner_actions_in_var,
                        planner_action_lengths.cpu().numpy(),
                        planner_hidden_idx_var, controller_img_feats_var,
                        controller_actions_in_var, controller_action_lengths)

                    planner_logprob = F.log_softmax(planner_scores, dim=1)
                    controller_logprob = F.log_softmax(controller_scores, dim=1)

                    planner_loss = planner_lossFn(
                        planner_logprob,
                        planner_actions_out_var[:, :planner_action_lengths.max()]
                        .contiguous().view(-1, 1),
                        planner_masks_var[:, :planner_action_lengths.max()]
                        .contiguous().view(-1, 1))

                    controller_loss = controller_lossFn(
                        controller_logprob,
                        controller_outs_var[:, :controller_action_lengths.max()]
                        .contiguous().view(-1, 1),
                        controller_masks_var[:, :controller_action_lengths.max()]
                        .contiguous().view(-1, 1))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update(
                        [planner_loss.data[0], controller_loss.data[0]])
                    logging.info(
                        "TRAINING PACMAN planner-loss: {:.6f} controller-loss: {:.6f}".format(
                            planner_loss.data[0], controller_loss.data[0]))

                    # backprop and update
                    if args.max_controller_actions == 1:
                        (planner_loss).backward()
                    else:
                        (planner_loss + controller_loss).backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(
                            metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                print('[CHECK][Cache:%d][Total:%d]' %
                      (len(train_loader.dataset.img_data_cache),
                       len(train_loader.dataset.env_list)))
                logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                    len(train_loader.dataset.img_data_cache),
                    len(train_loader.dataset.env_list)))

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True
        epoch += 1
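
        # Periodically save a per-thread checkpoint containing the (cleaned)
        # args, model weights, and optimizer state.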
        if epoch % args.save_every == 0:

            model_state = get_state(model)
            optimizer_state = optim.state_dict()

            aad = dict(args.__dict__)
            ad = {}
            for i in aad:
                if i[0] != '_':
                    ad[i] = aad[i]

            checkpoint = {'args': ad,
                          'state': model_state,
                          'epoch': epoch,
                          'optimizer': optimizer_state}

            checkpoint_path = '%s/epoch_%d_thread_%d.pt' % (
                args.checkpoint_dir, epoch, rank)

            print('Saving checkpoint to %s' % checkpoint_path)
            logging.info("TRAIN: Saving checkpoint to {}".format(checkpoint_path))
            torch.save(checkpoint, checkpoint_path)