def eval()

in training/train_vqa.py [0:0]


# Input types understood by eval(); each selects one model class.
_SUPPORTED_INPUT_TYPES = ('ques', 'ques,image')


def _loss_value(loss):
    """Return the Python number held by a single-element loss.

    ``loss.item()`` is used when available (PyTorch >= 0.4, where the loss is a
    0-dim tensor and ``loss.data[0]`` raises); otherwise fall back to the
    pre-0.4 ``loss.data[0]`` idiom.
    """
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def _record_batch(metrics, loss, scores, answers):
    """Add one batch's loss, accuracy and answer ranks to ``metrics``."""
    accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
    metrics.update([_loss_value(loss), accuracy, ranks, 1.0 / ranks])


def _eval_ques_epoch(model, loader, loss_fn, metrics):
    """Run the question-only model over every batch in ``loader`` and print stats."""
    for _, questions, answers in loader:
        questions_var = Variable(questions.cuda())
        answers_var = Variable(answers.cuda())

        scores = model(questions_var)
        loss = loss_fn(scores, answers_var)
        _record_batch(metrics, loss, scores, answers)

    print(metrics.get_stat_string(mode=0))


def _eval_ques_image_epoch(model, loader, loss_fn, metrics):
    """Run the question+image model over ``loader``, possibly in several passes.

    If ``loader.dataset._check_if_all_envs_loaded()`` returned False before the
    first pass, then after each pass ``_load_envs()`` is called and another pass
    runs, stopping once ``loader.dataset.pruned_env_set`` is empty. Stats are
    printed after every pass.
    """
    all_envs_loaded = loader.dataset._check_if_all_envs_loaded()

    while True:
        for _, questions, answers, images, _, _, _ in loader:
            questions_var = Variable(questions.cuda())
            answers_var = Variable(answers.cuda())
            images_var = Variable(images.cuda())

            # Second output (attention probabilities) is not needed here.
            scores, _ = model(images_var, questions_var)
            loss = loss_fn(scores, answers_var)
            _record_batch(metrics, loss, scores, answers)

        print(metrics.get_stat_string(mode=0))

        # Explicit comparison with False kept so a non-bool return value is
        # treated exactly as before.
        if all_envs_loaded != False:
            break
        loader.dataset._load_envs()
        if len(loader.dataset.pruned_env_set) == 0:
            break


def _checkpoint_args(args):
    """Return the args dict to store in a saved checkpoint.

    When ``args.checkpoint_path`` is set, the ``'args'`` entry recorded in that
    checkpoint file is reused; otherwise the current ``args.__dict__`` is stored.
    """
    if args.checkpoint_path != False:
        # Load onto CPU storage; only the 'args' entry is needed.
        loaded = torch.load(args.checkpoint_path,
                            map_location=lambda storage, loc: storage)
        return loaded['args']
    return args.__dict__


def eval(rank, args, shared_model):
    """Evaluate ``shared_model`` on ``args.eval_split`` once per epoch, saving the best.

    Each epoch copies ``shared_model``'s weights into a local model on this
    worker's GPU and runs it over the eval split (batch size 1), printing the
    metrics. When accuracy beats the best seen so far, the epoch number is a
    multiple of ``args.eval_every`` and ``args.log == True``, the metric log is
    dumped and a checkpoint is written to ``args.checkpoint_dir``.

    Args:
        rank: worker index; selects ``args.gpus[rank % len(args.gpus)]`` and
            names the log file ``eval_<rank>.json`` under ``args.log_dir``.
        args: parsed options. Reads input_type, vocab_json, gpus, eval_split
            (and the matching ``<eval_split>_h5`` attribute), data_json,
            num_frames, max_threads_per_gpu, cache, log_dir, max_epochs,
            eval_every, log, checkpoint_path and checkpoint_dir. Sets
            ``args.output_log_path``.
        shared_model: model whose ``state_dict()`` is loaded at the start of
            every epoch (presumably updated by a concurrent trainer — confirm).

    Raises:
        ValueError: if ``args.input_type`` is not 'ques' or 'ques,image'.
    """
    # Fail fast, before touching the GPU or building the data loader.
    if args.input_type not in _SUPPORTED_INPUT_TYPES:
        raise ValueError('unsupported input_type %r (expected one of %r)'
                         % (args.input_type, _SUPPORTED_INPUT_TYPES))

    gpu_id = args.gpus[rank % len(args.gpus)]
    # The device ordinal is gpu_id's position within args.gpus, not gpu_id itself.
    torch.cuda.set_device(args.gpus.index(gpu_id))

    vocab = load_vocab(args.vocab_json)
    if args.input_type == 'ques':
        model = VqaLstmModel(vocab=vocab)
    else:
        model = VqaLstmCnnAttentionModel(vocab=vocab)
    # Move once; load_state_dict() copies values into the existing parameters,
    # so they stay on the GPU across epochs.
    model.cuda()

    loss_fn = torch.nn.CrossEntropyLoss().cuda()

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': gpu_id,
        'to_cache': args.cache
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    epoch, best_eval_acc = 0, 0
    # Args stored with saved checkpoints; resolved on the first save only.
    saved_args = None

    while epoch < int(args.max_epochs):

        model.load_state_dict(shared_model.state_dict())
        model.eval()

        metrics = VqaMetric(
            info={'split': args.eval_split},
            metric_names=[
                'loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'
            ],
            log_json=args.output_log_path)

        if args.input_type == 'ques':
            _eval_ques_epoch(model, eval_loader, loss_fn, metrics)
        else:
            _eval_ques_image_epoch(model, eval_loader, loss_fn, metrics)

        epoch += 1

        # metrics.metrics[1] is the 'accuracy' entry (index 1 in metric_names).
        # NOTE(review): [0] assumed to be its aggregated value — confirm in VqaMetric.
        eval_acc = metrics.metrics[1][0]
        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            # NOTE(review): best_eval_acc advances even on epochs that are not
            # saved (epoch % eval_every != 0), so such a model is never written.
            if epoch % args.eval_every == 0 and args.log == True:
                metrics.dump_log()

                if saved_args is None:
                    saved_args = _checkpoint_args(args)

                checkpoint = {
                    'args': saved_args,
                    'state': get_state(model),
                    'epoch': epoch
                }

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)