def train()

in training/train_eqa.py [0:0]


def train(rank, args, shared_nav_model, shared_ans_model):
    """Run one navigation-training worker against the shared models.

    The worker keeps a private GPU copy of the navigation model, re-synced
    from ``shared_nav_model`` before every episode. For each batch it
    replays part of the expert action sequence, then samples planner and
    controller actions in the environment, and finally computes a
    policy-gradient loss (negative log-probability weighted by discounted
    return minus a reward baseline read from ``nav_metrics``). Gradients
    are handed to ``shared_nav_model`` via ``ensure_shared_grads`` and
    applied with SGD. The answering model is only used for inference to
    produce the final reward term.

    Args:
        rank: Worker index. Selects the GPU (``rank % len(args.gpus)``),
            names the per-worker log files, and is recorded as ``thread``
            in the metric info.
        args: Configuration namespace. Attributes read here: ``gpus``,
            ``model_type``, ``vocab_json``, ``learning_rate``,
            ``train_h5``, ``data_json``, ``target_obj_conn_map_dir``,
            ``map_resolution``, ``max_threads_per_gpu``, ``cache``,
            ``max_controller_actions``, ``max_actions``, ``log_dir``,
            ``max_epochs``, ``max_episode_length`` and ``log``.
            ``output_nav_log_path`` and ``output_ans_log_path`` are
            written onto it as a side effect.
        shared_nav_model: Navigation model whose parameters are optimised;
            the local copy loads its ``state_dict``.
        shared_ans_model: Answering model whose ``state_dict`` is copied
            into a local model that is put in eval mode.

    Returns:
        None.

    Raises:
        SystemExit: If ``args.model_type`` is not ``'pacman'``.
    """
    import sys

    # Fail fast, before selecting a CUDA device or building the loader.
    # sys.exit() with a message keeps the SystemExit the original bare
    # exit() raised, but reports why and yields a non-zero exit status.
    if args.model_type != 'pacman':
        sys.exit('train(): unsupported model_type %r (only \'pacman\' is '
                 'handled)' % (args.model_type, ))

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
    nav_model = NavPlannerControllerModel(**model_kwargs)

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

    # The optimiser steps the *shared* parameters; the local nav_model only
    # produces gradients that are copied over by ensure_shared_grads.
    optim = torch.optim.SGD(
        filter(lambda p: p.requires_grad, shared_nav_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache,
        'max_controller_actions': args.max_controller_actions,
        'max_actions': args.max_actions
    }

    # Per-worker log files, stored on args.
    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_train_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_train_' + str(rank) + '.json')

    nav_model.load_state_dict(shared_nav_model.state_dict())
    nav_model.cuda()

    # The answerer is never optimised here (it is not passed to optim).
    ans_model.load_state_dict(shared_ans_model.state_dict())
    ans_model.eval()
    ans_model.cuda()

    nav_metrics = NavMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=[
            'planner_loss', 'controller_loss', 'reward', 'episode_length'
        ],
        log_json=args.output_nav_log_path)

    vqa_metrics = VqaMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=['accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_ans_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch = 0, 0
    p_losses, c_losses, reward_list, episode_length_list = [], [], [], []

    # Seed the metrics (values given in metric_names order) so the reward
    # baseline read below exists before the first real update.
    nav_metrics.update([10.0, 10.0, 0, 100])

    # Curriculum factor: scales action_length in the spawn call below and
    # is raised in steps of 0.1 (capped at 1.0) as the reward metric grows.
    mult = 0.1

    while epoch < int(args.max_epochs):

        if 'pacman' in args.model_type:

            # NOTE(review): these criteria are constructed but not used
            # anywhere in this function; kept in case construction matters.
            planner_lossFn = MaskedNLLCriterion().cuda()
            controller_lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while not done:

                for batch in train_loader:

                    # Re-sync with the shared weights; the local model was
                    # moved to CPU by the previous gradient hand-off.
                    # NOTE(review): eval() is used although gradients are
                    # taken below - presumably intentional; confirm.
                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch

                    h3d = train_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    # for i in [10, 30, 50]:

                    t += 1

                    question_var = Variable(question.cuda())

                    controller_step = False
                    planner_hidden = nav_model.planner_nav_rnn.init_hidden(1)

                    # forward through planner till spawn
                    # (second argument presumably sets how far from the end
                    # of the expert path the agent spawns - TODO confirm)
                    (
                        planner_actions_in, planner_img_feats,
                        controller_step, controller_action_in,
                        controller_img_feat, init_pos,
                        controller_action_counter
                    ) = train_loader.dataset.get_hierarchical_features_till_spawn(
                        actions[0, :action_length[0] + 1].numpy(), max(3, int(mult * action_length[0])), args.max_controller_actions
                    )

                    planner_actions_in_var = Variable(
                        planner_actions_in.cuda())
                    planner_img_feats_var = Variable(planner_img_feats.cuda())

                    # Replay the expert prefix to warm up the planner's
                    # hidden state; only the last scores are used below.
                    for step in range(planner_actions_in.size(0)):

                        planner_scores, planner_hidden = nav_model.planner_step(
                            question_var, planner_img_feats_var[step].view(
                                1, 1, 3200), planner_actions_in_var[step].view(
                                    1, 1), planner_hidden)

                    if controller_step == True:

                        controller_img_feat_var = Variable(
                            controller_img_feat.cuda())
                        controller_action_in_var = Variable(
                            torch.LongTensor(1, 1).fill_(
                                int(controller_action_in)).cuda())

                        controller_scores = nav_model.controller_step(
                            controller_img_feat_var.view(1, 1, 3200),
                            controller_action_in_var.view(1, 1),
                            planner_hidden[0])

                        prob = F.softmax(controller_scores, dim=1)
                        controller_action = int(
                            prob.max(1)[1].data.cpu().numpy()[0])

                        # Controller output 1 means "keep repeating the
                        # current action".
                        controller_step = controller_action == 1

                        action = int(controller_action_in)
                        action_in = torch.LongTensor(
                            1, 1).fill_(action + 1).cuda()

                    else:

                        # Greedy (argmax) planner action for the first step.
                        prob = F.softmax(planner_scores, dim=1)
                        action = int(prob.max(1)[1].data.cpu().numpy()[0])

                        action_in = torch.LongTensor(
                            1, 1).fill_(action + 1).cuda()

                    h3d.env.reset(
                        x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                    init_dist_to_target = h3d.get_dist_to_target(
                        h3d.env.cam.pos)
                    if init_dist_to_target < 0:  # unreachable
                        # invalids.append([idx[0], i])
                        continue

                    episode_length = 0
                    episode_done = True
                    controller_action_counter = 0

                    dists_to_target, pos_queue = [init_dist_to_target], [
                        init_pos
                    ]

                    rewards, planner_actions, planner_log_probs, controller_actions, controller_log_probs = [], [], [], [], []

                    # action 3 is treated as terminal here (presumably the
                    # "stop" action): no rollout, go straight to answering.
                    if action != 3:

                        # take the first step
                        img, rwd, episode_done = h3d.step(action, step_reward=True)
                        img = torch.from_numpy(img.transpose(
                            2, 0, 1)).float() / 255.0
                        img_feat_var = train_loader.dataset.cnn(
                            Variable(img.view(1, 3, 224, 224).cuda())).view(
                                1, 1, 3200)

                        for step in range(args.max_episode_length):

                            episode_length += 1

                            # The planner picks a new action only when the
                            # controller did not ask to repeat the last one.
                            if controller_step == False:
                                planner_scores, planner_hidden = nav_model.planner_step(
                                    question_var, img_feat_var,
                                    Variable(action_in), planner_hidden)

                                planner_prob = F.softmax(planner_scores, dim=1)
                                planner_log_prob = F.log_softmax(
                                    planner_scores, dim=1)

                                # Sample the action and keep the log-prob of
                                # the sampled action for the loss.
                                action = planner_prob.multinomial().data
                                planner_log_prob = planner_log_prob.gather(
                                    1, Variable(action))

                                planner_log_probs.append(
                                    planner_log_prob.cpu())

                                action = int(action.cpu().numpy()[0, 0])
                                planner_actions.append(action)

                            img, rwd, episode_done = h3d.step(action, step_reward=True)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            rewards.append(rwd)

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = train_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224, 224)
                                         .cuda())).view(1, 1, 3200)

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done == True:
                                break

                            # query controller to continue or not
                            controller_action_in = Variable(
                                torch.LongTensor(1, 1).fill_(action).cuda())
                            controller_scores = nav_model.controller_step(
                                img_feat_var, controller_action_in,
                                planner_hidden[0])

                            controller_prob = F.softmax(
                                controller_scores, dim=1)
                            controller_log_prob = F.log_softmax(
                                controller_scores, dim=1)

                            controller_action = controller_prob.multinomial(
                            ).data

                            # Allow at most 4 consecutive repeats; beyond
                            # that the decision is forced to 0 (return
                            # control to the planner) and the forced value
                            # is what gets scored and recorded.
                            if int(controller_action[0]
                                   ) == 1 and controller_action_counter < 4:
                                controller_action_counter += 1
                                controller_step = True
                            else:
                                controller_action_counter = 0
                                controller_step = False
                                controller_action.fill_(0)

                            controller_log_prob = controller_log_prob.gather(
                                1, Variable(controller_action))
                            controller_log_probs.append(
                                controller_log_prob.cpu())

                            controller_action = int(
                                controller_action.cpu().numpy()[0, 0])
                            controller_actions.append(controller_action)
                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                    # run answerer here
                    ans_acc = [0]
                    if action == 3:
                        # The answerer takes 5 frames; pad from the
                        # dataset's episode position queue when the rollout
                        # was shorter than that.
                        if len(pos_queue) < 5:
                            pos_queue = train_loader.dataset.episode_pos_queue[len(
                                pos_queue) - 5:] + pos_queue
                        images = train_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)
                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)
                        vqa_metrics.update([ans_acc, ans_rank, 1.0 / ans_rank])

                    # Final reward term: success reward scaled by answer
                    # accuracy (0 when the answerer was not run).
                    rewards.append(h3d.success_reward * ans_acc[0])

                    R = torch.zeros(1, 1)

                    # Float zeros double as the "nothing accumulated"
                    # sentinel tested with isinstance(..., float) below.
                    planner_loss = 0.0
                    controller_loss = 0.0

                    # Baseline is the entry at index 2 ('reward') of the
                    # nav metrics - presumably a running average; confirm
                    # against NavMetric. It does not change inside the loop.
                    reward_baseline = nav_metrics.metrics[2][1]

                    # Walk rewards backwards, accumulating the discounted
                    # return (factor 0.99). planner_rev_idx consumes planner
                    # log-probs from the end, as there are fewer planner
                    # decisions than environment steps.
                    planner_rev_idx = -1
                    for i in reversed(range(len(rewards))):
                        R = 0.99 * R + rewards[i]
                        advantage = R - reward_baseline

                        if i < len(controller_actions):
                            controller_loss = controller_loss - controller_log_probs[i] * Variable(
                                advantage)

                            if controller_actions[i] == 0 and planner_rev_idx + len(planner_log_probs) >= 0:
                                planner_loss = planner_loss - planner_log_probs[planner_rev_idx] * Variable(
                                    advantage)
                                planner_rev_idx -= 1

                        elif planner_rev_idx + len(planner_log_probs) >= 0:

                            planner_loss = planner_loss - planner_log_probs[planner_rev_idx] * Variable(
                                advantage)
                            planner_rev_idx -= 1

                    controller_loss /= max(1, len(controller_log_probs))
                    planner_loss /= max(1, len(planner_log_probs))

                    optim.zero_grad()

                    # Update only when both losses actually accumulated
                    # log-prob terms (i.e. neither is still a plain float).
                    if isinstance(planner_loss, float) == False and isinstance(
                            controller_loss, float) == False:
                        p_losses.append(planner_loss.data[0, 0])
                        c_losses.append(controller_loss.data[0, 0])
                        reward_list.append(np.sum(rewards))
                        episode_length_list.append(episode_length)

                        (planner_loss + controller_loss).backward()

                        # Hand the local gradients to the shared model (the
                        # local model is moved to CPU for this), then step.
                        ensure_shared_grads(nav_model.cpu(), shared_nav_model)
                        optim.step()

                    # Flush accumulated statistics once more than 50
                    # episodes have been collected.
                    if len(reward_list) > 50:

                        nav_metrics.update([
                            p_losses, c_losses, reward_list,
                            episode_length_list
                        ])

                        print(nav_metrics.get_stat_string())
                        if args.log == True:
                            nav_metrics.dump_log()

                        # Advance the curriculum once the reward metric
                        # exceeds 0.35.
                        if nav_metrics.metrics[2][1] > 0.35:
                            mult = min(mult + 0.1, 1.0)

                        p_losses, c_losses, reward_list, episode_length_list = [], [], [], []

                # Not every environment fits in memory at once: load the
                # next set and keep iterating until none remain.
                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        epoch += 1