in training/train_eqa.py [0:0]
def train(rank, args, shared_nav_model, shared_ans_model):
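# RL fine-tuning worker (one per process): it rolls out the PACMAN
# planner-controller navigator, scores the final answer with a frozen VQA
# model, and updates the shared navigator with REINFORCE, A3C-style.
# GPUs are assigned to workers round-robin by rank.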
torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))
if args.model_type == 'pacman':
model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
nav_model = NavPlannerControllerModel(**model_kwargs)
else:
exit()
model_kwargs = {'vocab': load_vocab(args.vocab_json)}
ans_model = VqaLstmCnnAttentionModel(**model_kwargs)
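# The optimizer holds the parameters of the *shared* navigation model; each
# worker backpropagates on its local copy and copies the gradients into the
# shared model (via ensure_shared_grads) before stepping.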
optim = torch.optim.SGD(
filter(lambda p: p.requires_grad, shared_nav_model.parameters()),
lr=args.learning_rate)
train_loader_kwargs = {
'questions_h5': args.train_h5,
'data_json': args.data_json,
'vocab': args.vocab_json,
'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
'map_resolution': args.map_resolution,
'batch_size': 1,
'input_type': args.model_type,
'num_frames': 5,
'split': 'train',
'max_threads_per_gpu': args.max_threads_per_gpu,
'gpu_id': args.gpus[rank % len(args.gpus)],
'to_cache': args.cache,
'max_controller_actions': args.max_controller_actions,
'max_actions': args.max_actions
}
args.output_nav_log_path = os.path.join(args.log_dir,
'nav_train_' + str(rank) + '.json')
args.output_ans_log_path = os.path.join(args.log_dir,
'ans_train_' + str(rank) + '.json')
nav_model.load_state_dict(shared_nav_model.state_dict())
nav_model.cuda()
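# The answering model stays frozen (eval mode); it is only used at the end of
# an episode to score the predicted answer, which provides the terminal reward.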
ans_model.load_state_dict(shared_ans_model.state_dict())
ans_model.eval()
ans_model.cuda()
nav_metrics = NavMetric(
info={'split': 'train',
'thread': rank},
metric_names=[
'planner_loss', 'controller_loss', 'reward', 'episode_length'
],
log_json=args.output_nav_log_path)
vqa_metrics = VqaMetric(
info={'split': 'train',
'thread': rank},
metric_names=['accuracy', 'mean_rank', 'mean_reciprocal_rank'],
log_json=args.output_ans_log_path)
train_loader = EqaDataLoader(**train_loader_kwargs)
print('train_loader has %d samples' % len(train_loader.dataset))
t, epoch = 0, 0
p_losses, c_losses, reward_list, episode_length_list = [], [], [], []
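# Seed the navigation metrics with dummy values so the running reward
# statistic (used below as the REINFORCE baseline) is defined, and start the
# curriculum: `mult` scales how much of the expert trajectory is used when
# choosing the agent's spawn point.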
nav_metrics.update([10.0, 10.0, 0, 100])
mult = 0.1
while epoch < int(args.max_epochs):
if 'pacman' in args.model_type:
planner_lossFn = MaskedNLLCriterion().cuda()
controller_lossFn = MaskedNLLCriterion().cuda()
done = False
all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()
while not done:
for batch in train_loader:
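# Pull the latest shared weights into the local navigator before each
# episode; other workers update the shared model asynchronously.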
nav_model.load_state_dict(shared_nav_model.state_dict())
nav_model.eval()
nav_model.cuda()
idx, question, answer, actions, action_length = batch
metrics_slug = {}
h3d = train_loader.dataset.episode_house
# evaluate at multiple initializations
# for i in [10, 30, 50]:
t += 1
question_var = Variable(question.cuda())
controller_step = False
planner_hidden = nav_model.planner_nav_rnn.init_hidden(1)
# forward through planner till spawn
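# Rebuild the state at a spawn point along the expert trajectory: the
# curriculum term max(3, int(mult * action_length)) controls how many expert
# actions are involved, and the returned prefix features teacher-force the
# planner (plus one controller step if the spawn falls inside a controller
# segment) so its hidden state matches the expert's up to the spawn.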
(
planner_actions_in, planner_img_feats,
controller_step, controller_action_in,
controller_img_feat, init_pos,
controller_action_counter
) = train_loader.dataset.get_hierarchical_features_till_spawn(
actions[0, :action_length[0] + 1].numpy(), max(3, int(mult * action_length[0])), args.max_controller_actions
)
planner_actions_in_var = Variable(
planner_actions_in.cuda())
planner_img_feats_var = Variable(planner_img_feats.cuda())
for step in range(planner_actions_in.size(0)):
planner_scores, planner_hidden = nav_model.planner_step(
question_var, planner_img_feats_var[step].view(
1, 1, 3200), planner_actions_in_var[step].view(
1, 1), planner_hidden)
if controller_step:
controller_img_feat_var = Variable(
controller_img_feat.cuda())
controller_action_in_var = Variable(
torch.LongTensor(1, 1).fill_(
int(controller_action_in)).cuda())
controller_scores = nav_model.controller_step(
controller_img_feat_var.view(1, 1, 3200),
controller_action_in_var.view(1, 1),
planner_hidden[0])
prob = F.softmax(controller_scores, dim=1)
controller_action = int(
prob.max(1)[1].data.cpu().numpy()[0])
controller_step = (controller_action == 1)
action = int(controller_action_in)
action_in = torch.LongTensor(
1, 1).fill_(action + 1).cuda()
else:
prob = F.softmax(planner_scores, dim=1)
action = int(prob.max(1)[1].data.cpu().numpy()[0])
action_in = torch.LongTensor(
1, 1).fill_(action + 1).cuda()
h3d.env.reset(
x=init_pos[0], y=init_pos[2], yaw=init_pos[3])
init_dist_to_target = h3d.get_dist_to_target(
h3d.env.cam.pos)
if init_dist_to_target < 0: # unreachable
# invalids.append([idx[0], i])
continue
episode_length = 0
episode_done = True
controller_action_counter = 0
dists_to_target, pos_queue = [init_dist_to_target], [init_pos]
rewards, planner_actions, planner_log_probs, controller_actions, controller_log_probs = [], [], [], [], []
if action != 3:
# take the first step
img, rwd, episode_done = h3d.step(action, step_reward=True)
img = torch.from_numpy(img.transpose(
2, 0, 1)).float() / 255.0
img_feat_var = train_loader.dataset.cnn(
Variable(img.view(1, 3, 224, 224).cuda())).view(
1, 1, 3200)
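# Main rollout: the planner samples an action; after every step the
# controller decides whether to repeat that action (at most 4 consecutive
# repeats) or hand control back to the planner. Log-probabilities of both
# are stored for the REINFORCE update below.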
for step in range(args.max_episode_length):
episode_length += 1
if not controller_step:
planner_scores, planner_hidden = nav_model.planner_step(
question_var, img_feat_var,
Variable(action_in), planner_hidden)
planner_prob = F.softmax(planner_scores, dim=1)
planner_log_prob = F.log_softmax(
planner_scores, dim=1)
action = planner_prob.multinomial().data
planner_log_prob = planner_log_prob.gather(
1, Variable(action))
planner_log_probs.append(
planner_log_prob.cpu())
action = int(action.cpu().numpy()[0, 0])
planner_actions.append(action)
img, rwd, episode_done = h3d.step(action, step_reward=True)
episode_done = episode_done or episode_length >= args.max_episode_length
rewards.append(rwd)
img = torch.from_numpy(img.transpose(
2, 0, 1)).float() / 255.0
img_feat_var = train_loader.dataset.cnn(
Variable(img.view(1, 3, 224, 224)
.cuda())).view(1, 1, 3200)
dists_to_target.append(
h3d.get_dist_to_target(h3d.env.cam.pos))
pos_queue.append([
h3d.env.cam.pos.x, h3d.env.cam.pos.y,
h3d.env.cam.pos.z, h3d.env.cam.yaw
])
if episode_done:
break
# query controller to continue or not
controller_action_in = Variable(
torch.LongTensor(1, 1).fill_(action).cuda())
controller_scores = nav_model.controller_step(
img_feat_var, controller_action_in,
planner_hidden[0])
controller_prob = F.softmax(
controller_scores, dim=1)
controller_log_prob = F.log_softmax(
controller_scores, dim=1)
controller_action = controller_prob.multinomial().data
if int(controller_action[0]) == 1 and controller_action_counter < 4:
controller_action_counter += 1
controller_step = True
else:
controller_action_counter = 0
controller_step = False
controller_action.fill_(0)
controller_log_prob = controller_log_prob.gather(
1, Variable(controller_action))
controller_log_probs.append(
controller_log_prob.cpu())
controller_action = int(
controller_action.cpu().numpy()[0, 0])
controller_actions.append(controller_action)
action_in = torch.LongTensor(
1, 1).fill_(action + 1).cuda()
# run answerer here
ans_acc = [0]
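# The answerer only runs if the agent emitted the stop action (3): it looks
# at the last five frames along the agent's path, padded with expert
# positions if the rollout visited fewer than five.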
if action == 3:
if len(pos_queue) < 5:
pos_queue = train_loader.dataset.episode_pos_queue[len(pos_queue) - 5:] + pos_queue
images = train_loader.dataset.get_frames(
h3d, pos_queue[-5:], preprocess=True)
images_var = Variable(
torch.from_numpy(images).cuda()).view(
1, 5, 3, 224, 224)
scores, att_probs = ans_model(images_var, question_var)
ans_acc, ans_rank = vqa_metrics.compute_ranks(
scores.data.cpu(), answer)
vqa_metrics.update([ans_acc, ans_rank, 1.0 / ans_rank])
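# Terminal reward: the environment's success reward scaled by answer
# accuracy (stays 0 if the agent never stopped), credited through the
# discounted return below.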
rewards.append(h3d.success_reward * ans_acc[0])
R = torch.zeros(1, 1)
planner_loss = 0
controller_loss = 0
planner_rev_idx = -1
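# REINFORCE with a baseline: walk the episode backwards accumulating the
# discounted return (gamma = 0.99); the advantage subtracts the running
# reward statistic tracked by nav_metrics. Controller terms are indexed
# directly by step; planner log-probs are consumed from the back on steps
# where the controller returned control (controller_actions[i] == 0) or
# where no controller decision was recorded.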
for i in reversed(range(len(rewards))):
R = 0.99 * R + rewards[i]
advantage = R - nav_metrics.metrics[2][1]
if i < len(controller_actions):
controller_loss = controller_loss - controller_log_probs[i] * Variable(
advantage)
if controller_actions[i] == 0 and planner_rev_idx + len(planner_log_probs) >= 0:
planner_loss = planner_loss - planner_log_probs[planner_rev_idx] * Variable(
advantage)
planner_rev_idx -= 1
elif planner_rev_idx + len(planner_log_probs) >= 0:
planner_loss = planner_loss - planner_log_probs[planner_rev_idx] * Variable(
advantage)
planner_rev_idx -= 1
controller_loss /= max(1, len(controller_log_probs))
planner_loss /= max(1, len(planner_log_probs))
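# A3C-style update: backpropagate on the local model, move it to CPU, copy
# its gradients into the shared navigator, and step the optimizer that holds
# the shared parameters. Episodes whose losses are still plain floats (no
# sampled actions) are skipped.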
optim.zero_grad()
if not isinstance(planner_loss, float) and not isinstance(controller_loss, float):
p_losses.append(planner_loss.data[0, 0])
c_losses.append(controller_loss.data[0, 0])
reward_list.append(np.sum(rewards))
episode_length_list.append(episode_length)
(planner_loss + controller_loss).backward()
ensure_shared_grads(nav_model.cpu(), shared_nav_model)
optim.step()
if len(reward_list) > 50:
nav_metrics.update([
p_losses, c_losses, reward_list,
episode_length_list
])
print(nav_metrics.get_stat_string())
if args.log:
nav_metrics.dump_log()
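# Curriculum: once the running reward exceeds 0.35, grow `mult` (capped at
# 1.0) so a larger fraction of the expert trajectory is used when picking
# the spawn point.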
if nav_metrics.metrics[2][1] > 0.35:
mult = min(mult + 0.1, 1.0)
p_losses, c_losses, reward_list, episode_length_list = [], [], [], []
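# Environments are loaded a chunk at a time. If not all of them are
# resident, cycle to the next chunk; once the pruned set is exhausted the
# epoch ends (and, when caching is off, loading restarts from the first
# environment for the next epoch).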
if not all_envs_loaded:
train_loader.dataset._load_envs(in_order=True)
if len(train_loader.dataset.pruned_env_set) == 0:
done = True
if not args.cache:
train_loader.dataset._load_envs(
start_idx=0, in_order=True)
else:
done = True
epoch += 1