# def eval()
#
# in training/train_nav.py [0:0]


# Metric columns recorded by NavMetric during navigation evaluation. Each
# family is reported separately for every spawn offset in _EVAL_INIT_STEPS:
#   d_0    - distance to target at spawn
#   d_T    - distance to target when the rollout ends
#   d_D    - d_0 - d_T (progress made towards the target)
#   d_min  - smallest distance to target seen during the rollout
#   r_T    - 1 if the final position is inside the target room, else 0
#   r_e    - 1 if any visited position is inside the target room, else 0
#   stop   - 1 if the rollout's last action was _STOP_ACTION, else 0
#   ep_len - number of steps taken in the rollout
_NAV_EVAL_METRIC_NAMES = [
    'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
    'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
    'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
    'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
    'ep_len_30', 'ep_len_50'
]

# Spawn offsets: the agent is started (roughly) this many steps before the
# end of the loaded episode; the exact indexing differs per model family.
_EVAL_INIT_STEPS = (10, 30, 50)

# Action index that the rollouts below treat as <stop>.
_STOP_ACTION = 3

# Metric used to select checkpoints; higher is better.
_CHECKPOINT_METRIC = 'd_D_50'


def _new_nav_metrics(args, rank):
    """Return a fresh NavMetric tracker for eval worker ``rank``."""
    return NavMetric(
        info={'split': args.eval_split,
              'thread': rank},
        metric_names=list(_NAV_EVAL_METRIC_NAMES),
        log_json=args.output_log_path)


def _build_nav_model(args):
    """Instantiate the navigation model named by ``args.model_type``.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model type.
    """
    if args.model_type == 'cnn':
        return NavCnnModel()
    if args.model_type == 'cnn+q':
        return NavCnnModel(question_input=True,
                           question_vocab=load_vocab(args.vocab_json))
    if args.model_type == 'lstm':
        return NavCnnRnnModel()
    if args.model_type == 'lstm+q':
        return NavCnnRnnModel(question_input=True,
                              question_vocab=load_vocab(args.vocab_json))
    if args.model_type == 'lstm-mult+q':
        return NavCnnRnnMultModel(question_input=True,
                                  question_vocab=load_vocab(args.vocab_json))
    if args.model_type == 'pacman':
        return NavPlannerControllerModel(
            question_vocab=load_vocab(args.vocab_json))
    raise ValueError('Unsupported model_type: %r' % (args.model_type,))


def _episode_stats(h3d, target_room, dists_to_target, pos_queue,
                   episode_length, last_action, init_steps):
    """Summarise one rollout as NavMetric entries suffixed with ``init_steps``.

    Args:
        h3d: environment handle; only ``is_inside_room(pos, room)`` is used.
        target_room: room passed through to ``h3d.is_inside_room``.
        dists_to_target: distances to target, starting with the spawn
            distance (must be non-empty).
        pos_queue: positions visited, starting with the spawn position
            (must be non-empty).
        episode_length: number of steps taken.
        last_action: final action of the rollout (``None`` if none taken).
        init_steps: spawn offset, used as the metric-name suffix.

    Returns:
        dict mapping e.g. ``'d_T_10'`` to its value for this rollout.
    """
    suffix = '_' + str(init_steps)
    inside_room = [h3d.is_inside_room(p, target_room) for p in pos_queue]
    return {
        'd_0' + suffix: dists_to_target[0],
        'd_T' + suffix: dists_to_target[-1],
        'd_D' + suffix: dists_to_target[0] - dists_to_target[-1],
        'd_min' + suffix: np.array(dists_to_target).min(),
        'ep_len' + suffix: episode_length,
        'stop' + suffix: 1 if last_action == _STOP_ACTION else 0,
        'r_T' + suffix: 1 if inside_room[-1] else 0,
        'r_e' + suffix: 1 if any(inside_room) else 0,
    }


def _collate_metrics(metrics, metrics_slug):
    """Order ``metrics_slug`` values by ``metrics.metric_names``.

    Names missing from ``metrics_slug`` (invalid spawns) are filled with
    ``metrics.metrics[k][0]`` -- presumably NavMetric's current running value,
    so invalid spawns leave that average unchanged (confirm against NavMetric).
    """
    return [
        metrics_slug[name] if name in metrics_slug else metrics.metrics[k][0]
        for k, name in enumerate(metrics.metric_names)
    ]


def eval(rank, args, shared_model):
    """Evaluate a navigation model by rolling it out in the loader's environments.

    Runs ``args.max_epochs`` passes (one pass when ``args.mode == 'eval'``).
    In each pass, for every episode yielded by ``EqaDataLoader`` and every
    spawn offset in ``_EVAL_INIT_STEPS``, the agent is reset at the spawn
    position and acts greedily (argmax over the softmax) until it stops, the
    environment reports done, or ``args.max_episode_length`` steps elapse.
    The per-rollout statistics in ``_NAV_EVAL_METRIC_NAMES`` are accumulated
    in a NavMetric. Environments are loaded in chunks via
    ``dataset._load_envs()``; a pass ends when no chunk remains
    (``pruned_env_set`` is empty).

    After each pass, if d_D_50 exceeds the best value so far (initially 0.0),
    it becomes the new best; additionally, when ``args.log`` is set and the
    epoch count is a multiple of ``args.eval_every``, the metric log is dumped
    and a checkpoint is saved to ``args.checkpoint_dir``.

    Args:
        rank: worker index. Selects the GPU ``args.gpus[rank % len(args.gpus)]``
            and the metric log ``eval_<rank>.json`` under ``args.log_dir``.
        args: parsed options namespace. Side effect: ``args.output_log_path``
            is set by this function.
        shared_model: model whose ``state_dict`` is copied into the local
            model before every batch.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model type.
    """
    gpu_id = args.gpus[rank % len(args.gpus)]
    # set_device receives the position of gpu_id within args.gpus, while the
    # loader receives the raw id (presumably CUDA_VISIBLE_DEVICES == args.gpus;
    # confirm).
    torch.cuda.set_device(args.gpus.index(gpu_id))

    model = _build_nav_model(args)

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': gpu_id,
        'to_cache': False,
        'overfit': args.overfit,
        'max_controller_actions': args.max_controller_actions,
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    dataset = eval_loader.dataset
    print('eval_loader has %d samples' % len(dataset))
    logging.info("EVAL: eval_loader has {} samples".format(len(dataset)))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    epoch, best_eval_acc = 0, 0.0

    max_epochs = 1 if args.mode == 'eval' else args.max_epochs
    while epoch < int(max_epochs):

        # Entries are [episode idx, spawn offset] for skipped spawns.
        invalids = []

        model.load_state_dict(shared_model.state_dict())
        model.eval()

        metrics = _new_nav_metrics(args, rank)

        if 'cnn' in args.model_type:

            done = False
            while not done:

                for batch in tqdm(eval_loader):

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, questions, _, img_feats, actions_in, actions_out, action_length = batch
                    metrics_slug = {}

                    # evaluate at multiple initializations
                    for i in _EVAL_INIT_STEPS:

                        # Need five frames of history before the spawn point.
                        if action_length[0] + 1 - i - 5 < 0:
                            invalids.append([idx[0], i])
                            continue

                        # The CNN model sees a sliding window of the last five
                        # frame features; seed it with the five frames ending
                        # at the spawn point.
                        ep_inds = list(range(action_length[0] + 1 - i - 5,
                                             action_length[0] + 1 - i))

                        sub_img_feats = torch.index_select(
                            img_feats, 1, torch.LongTensor(ep_inds))

                        init_pos = dataset.episode_pos_queue[ep_inds[-1]]

                        h3d = dataset.episode_house

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        sub_img_feats_var = Variable(sub_img_feats.cuda())
                        if '+q' in args.model_type:
                            questions_var = Variable(questions.cuda())

                        # sample actions till max steps or <stop>
                        episode_length = 0
                        dists_to_target, pos_queue, actions = [
                            init_dist_to_target
                        ], [init_pos], []

                        for _ in range(args.max_episode_length):

                            episode_length += 1

                            if '+q' in args.model_type:
                                scores = model(sub_img_feats_var,
                                               questions_var)
                            else:
                                scores = model(sub_img_feats_var)

                            # Greedy decoding: take the most likely action.
                            prob = F.softmax(scores, dim=1)
                            action = int(prob.max(1)[1].data.cpu().numpy()[0])
                            actions.append(action)

                            img, _, episode_done = h3d.step(action)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = dataset.cnn(
                                Variable(img.view(1, 3, 224, 224)
                                         .cuda())).view(1, 1, 3200)
                            # Slide the 5-frame window forward by one frame.
                            sub_img_feats_var = torch.cat(
                                [sub_img_feats_var, img_feat_var], dim=1)
                            sub_img_feats_var = sub_img_feats_var[:, -5:, :]

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done:
                                break

                        metrics_slug.update(_episode_stats(
                            h3d, dataset.target_room, dists_to_target,
                            pos_queue, episode_length,
                            actions[-1] if actions else None, i))

                    metrics.update(_collate_metrics(metrics, metrics_slug))

                stat_string = metrics.get_stat_string(mode=0)
                print(stat_string)
                print('invalids', len(invalids))
                logging.info("EVAL: metrics: {}".format(stat_string))
                logging.info("EVAL: invalids: {}".format(len(invalids)))

                # Load the next chunk of environments; stop when none remain.
                dataset._load_envs()
                done = len(dataset.pruned_env_set) == 0

        elif 'lstm' in args.model_type:

            done = False
            while not done:

                if args.overfit:
                    metrics = _new_nav_metrics(args, rank)

                for batch in tqdm(eval_loader):

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, questions, answer, _, actions_in, actions_out, action_lengths, _ = batch
                    question_var = Variable(questions.cuda())
                    # Question-less LSTM models receive False in the question slot.
                    question_arg = (question_var
                                    if '+q' in args.model_type else False)
                    metrics_slug = {}

                    # evaluate at multiple initializations
                    for i in _EVAL_INIT_STEPS:

                        if action_lengths[0] - 1 - i < 0:
                            invalids.append([idx[0], i])
                            continue

                        h3d = dataset.episode_house

                        # Forward the LSTM over the trajectory prefix up to the
                        # spawn point so its hidden state reflects that prefix.
                        spawn_prefix = dataset.episode_pos_queue[:-i]
                        if len(spawn_prefix) > 0:
                            images = dataset.get_frames(
                                h3d, spawn_prefix, preprocess=True)
                            raw_img_feats = dataset.cnn(
                                Variable(torch.FloatTensor(images).cuda()))

                            actions_in_pruned = actions_in[:, :action_lengths[0] - i]
                            actions_in_var = Variable(actions_in_pruned.cuda())
                            action_lengths_pruned = action_lengths.clone(
                            ).fill_(action_lengths[0] - i)
                            img_feats_var = raw_img_feats.view(1, -1, 3200)

                            _, hidden = model(
                                img_feats_var, question_arg, actions_in_var,
                                action_lengths_pruned.cpu().numpy())
                            try:
                                init_pos = dataset.episode_pos_queue[-i]
                            except IndexError:
                                invalids.append([idx[0], i])
                                continue

                            action_in = torch.LongTensor(1, 1).fill_(
                                actions_in[0, action_lengths[0] - i]).cuda()
                        else:
                            init_pos = dataset.episode_pos_queue[-i]
                            hidden = model.nav_rnn.init_hidden(1)
                            action_in = torch.LongTensor(1, 1).fill_(0).cuda()

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        img = h3d.env.render()
                        img = torch.from_numpy(img.transpose(
                            2, 0, 1)).float() / 255.0
                        img_feat_var = dataset.cnn(
                            Variable(img.view(1, 3, 224, 224).cuda())).view(
                                1, 1, 3200)

                        episode_length = 0
                        dists_to_target, pos_queue, actions = [
                            init_dist_to_target
                        ], [init_pos], []

                        for _ in range(args.max_episode_length):

                            episode_length += 1

                            scores, hidden = model(
                                img_feat_var,
                                question_arg,
                                Variable(action_in),
                                False,
                                hidden=hidden,
                                step=True)

                            # Greedy decoding: take the most likely action.
                            prob = F.softmax(scores, dim=1)
                            action = int(prob.max(1)[1].data.cpu().numpy()[0])
                            actions.append(action)

                            img, _, episode_done = h3d.step(action)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = dataset.cnn(
                                Variable(img.view(1, 3, 224, 224)
                                         .cuda())).view(1, 1, 3200)

                            # The previous action is fed back shifted by one.
                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done:
                                break

                        metrics_slug.update(_episode_stats(
                            h3d, dataset.target_room, dists_to_target,
                            pos_queue, episode_length,
                            actions[-1] if actions else None, i))

                    metrics.update(_collate_metrics(metrics, metrics_slug))

                stat_string = metrics.get_stat_string(mode=0)
                print(stat_string)
                print('invalids', len(invalids))
                logging.info("EVAL: metrics: {}".format(stat_string))
                logging.info("EVAL: invalids: {}".format(len(invalids)))

                # Load the next chunk of environments; an empty
                # pruned_env_set means every chunk has been evaluated.
                dataset._load_envs()
                n_envs = len(dataset.pruned_env_set)
                print("eval_loader pruned_env_set len: {}".format(n_envs))
                logging.info("eval_loader pruned_env_set len: {}".format(n_envs))
                done = n_envs == 0

        elif 'pacman' in args.model_type:

            done = False
            while not done:

                if args.overfit:
                    metrics = _new_nav_metrics(args, rank)

                for batch in tqdm(eval_loader):

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, question, answer, actions, action_length = batch
                    metrics_slug = {}

                    h3d = dataset.episode_house

                    # evaluate at multiple initializations
                    for i in _EVAL_INIT_STEPS:

                        if i > action_length[0]:
                            invalids.append([idx[0], i])
                            continue

                        question_var = Variable(question.cuda())

                        planner_hidden = model.planner_nav_rnn.init_hidden(1)

                        # get hierarchical action history
                        (
                            planner_actions_in, planner_img_feats,
                            controller_step, controller_action_in,
                            controller_img_feats, init_pos,
                            controller_action_counter
                        ) = dataset.get_hierarchical_features_till_spawn(
                            actions[0, :action_length[0] + 1].numpy(), i, args.max_controller_actions
                        )

                        planner_actions_in_var = Variable(
                            planner_actions_in.cuda())
                        planner_img_feats_var = Variable(
                            planner_img_feats.cuda())

                        # forward planner till spawn to update hidden state
                        for step in range(planner_actions_in.size(0)):

                            planner_scores, planner_hidden = model.planner_step(
                                question_var, planner_img_feats_var[step]
                                .unsqueeze(0).unsqueeze(0),
                                planner_actions_in_var[step].view(1, 1),
                                planner_hidden
                            )

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        dists_to_target, pos_queue = [init_dist_to_target], [init_pos]

                        episode_length = 0
                        if args.max_controller_actions > 1:
                            controller_action_counter = controller_action_counter % args.max_controller_actions
                            controller_action_counter = max(controller_action_counter - 1, 0)
                        else:
                            controller_action_counter = 0

                        first_step = True
                        first_step_is_controller = controller_step
                        planner_step = True
                        action = int(controller_action_in)

                        for _ in range(args.max_episode_length):
                            if not first_step:
                                img = torch.from_numpy(img.transpose(
                                    2, 0, 1)).float() / 255.0
                                img_feat_var = dataset.cnn(
                                    Variable(img.view(1, 3, 224,
                                                      224).cuda())).view(
                                                          1, 1, 3200)
                            else:
                                img_feat_var = Variable(controller_img_feats.cuda()).view(1, 1, 3200)

                            if not first_step or first_step_is_controller:
                                # query controller to continue or not
                                controller_action_in = Variable(
                                    torch.LongTensor(1, 1).fill_(action).cuda())
                                controller_scores = model.controller_step(
                                    img_feat_var, controller_action_in,
                                    planner_hidden[0])

                                prob = F.softmax(controller_scores, dim=1)
                                controller_action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

                                # Repeat the current action (controller says
                                # 1) at most max_controller_actions - 1 times
                                # before handing control back to the planner.
                                if controller_action == 1 and controller_action_counter < args.max_controller_actions - 1:
                                    controller_action_counter += 1
                                    planner_step = False
                                else:
                                    controller_action_counter = 0
                                    planner_step = True

                                first_step = False

                            if planner_step:
                                if not first_step:
                                    action_in = torch.LongTensor(
                                        1, 1).fill_(action + 1).cuda()
                                    planner_scores, planner_hidden = model.planner_step(
                                        question_var, img_feat_var,
                                        Variable(action_in), planner_hidden)

                                prob = F.softmax(planner_scores, dim=1)
                                action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

                            episode_done = action == _STOP_ACTION or episode_length >= args.max_episode_length

                            episode_length += 1
                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done:
                                break

                            img, _, _ = h3d.step(action)
                            first_step = False

                        metrics_slug.update(_episode_stats(
                            h3d, dataset.target_room, dists_to_target,
                            pos_queue, episode_length, action, i))

                    metrics.update(_collate_metrics(metrics, metrics_slug))

                # Best-effort summary: a failure here must not abort eval,
                # but it is logged rather than silently discarded.
                try:
                    stat_string = metrics.get_stat_string(mode=0)
                except Exception:
                    logging.exception("EVAL: could not compute metrics summary")
                else:
                    print(stat_string)
                    logging.info("EVAL: metrics: {}".format(stat_string))

                print('epoch', epoch)
                print('invalids', len(invalids))
                logging.info("EVAL: epoch {}".format(epoch))
                logging.info("EVAL: invalids {}".format(len(invalids)))
                logging.debug("EVAL: invalid [idx, init_steps] pairs: {}".format(invalids))

                # Load the next chunk of environments; stop when none remain.
                dataset._load_envs()
                done = len(dataset.pruned_env_set) == 0

        epoch += 1

        # checkpoint if the selection metric improved
        eval_acc = metrics.metrics[
            metrics.metric_names.index(_CHECKPOINT_METRIC)][0]
        # NOTE(review): guard in case the metric was never filled in (e.g.
        # every spawn was invalid); NavMetric may leave it as None.
        if eval_acc is not None and eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            if epoch % args.eval_every == 0 and args.log == True:
                metrics.dump_log()

                model_state = get_state(model)

                # Store only the public attributes of args in the checkpoint.
                ad = {k: v for k, v in vars(args).items()
                      if not k.startswith('_')}

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_d_D_50_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                logging.info("EVAL: Saving checkpoint to {}".format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_d_D_50:%.04f]' % best_eval_acc)
        logging.info("EVAL: [best_eval_d_D_50:{:.04f}]".format(best_eval_acc))

        # Rewind to the first environments (in order) for the next pass.
        dataset._load_envs(start_idx=0, in_order=True)