def eval_split()

in captioning/utils/eval_utils_for_coco_caption.py

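The body below assumes the usual module-level imports of the surrounding file; a plausible set is sketched here (the repository-internal paths are guesses inferred from the call sites, not verified, and eval_split_n / language_eval are expected to be defined alongside this function):

    import os
    import numpy as np
    import torch
    import torch.nn.functional as F
    import captioning.utils.misc as utils            # assumed home of decode_sequence
    import captioning.models.utils as utils_models   # assumed home of repeat_tensors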

def eval_split(model, crit, loader, task='caption', eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
    device = eval_kwargs.get('device', 'cuda')

    # assert task
    assert task in ['caption', 'trace', 'both', 'show']

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8  # avoids division by zero in the final loss_sum / loss_evals
    predictions = []
    n_predictions = [] # when sample_n > 1
    grounding_quality_loss = []  # for downstream task 2
    while True:
        data = loader.get_batch(split)
        n = n + len(data['infos'])

        tmp = [data['fc_feats'], data['att_feats'], data['trace_feats'], data['box_feats'], data['labels'],
               data['masks'], data['att_masks'], data['trace_masks'], data['show_labels'], data['show_trace_feats'], data['show_trace_masks'], data['show_masks'], data['show_gate_labels']]
        tmp = [_.to(device) if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, trace_feats, box_feats, labels, masks, att_masks, \
            trace_masks, show_labels, show_trace_feats, show_trace_masks, show_masks, show_gate_labels = tmp
        if labels is not None and verbose_loss:
            # forward the model to get loss
            with torch.no_grad():
                if task == 'caption':
                    loss = 0
                    # loss = crit(model(fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1], att_masks, trace_masks, task=task), labels[..., 1:], masks[..., 1:]).item()
                elif task == 'show':
                    loss = crit(
                        model(fc_feats, att_feats, show_trace_feats, box_feats, show_labels[..., :-1], att_masks,
                              show_trace_masks, show_gate_labels, task=task), show_labels[..., 1:], show_masks[..., 1:]).item()
                elif task == 'both':
                    loss = crit(
                        model(fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1], att_masks, trace_masks,
                              task=task)[0], labels[..., 1:], masks[..., 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1



        test_grounding_quality = True
        test_baseline = False
        # forward the model to also get generated samples for each image
        with torch.no_grad():
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            ### repeat att feats
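            ### (assumption: the factor 5 matches the 5 show-control traces/captions per
            ### image, so image-level features are tiled 5x to score grounding for each one)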
            if test_grounding_quality:
                fc_feats, att_feats, att_masks, box_feats = utils_models.repeat_tensors(5,
                                                                   [fc_feats, att_feats, att_masks, box_feats]
                                                                   )
            #############################
            if task == 'both':
                seq, seq_logprobs, _ = model(fc_feats, att_feats, show_trace_feats[:att_feats.shape[0]], box_feats, att_masks, show_trace_masks[:att_feats.shape[0]],
                                          show_gate_labels[:att_feats.shape[0]], task, opt=tmp_eval_kwargs, mode='sample')
                # use the ground-truth show-caption to get the predicted trace

                _, trace_output = model(fc_feats, att_feats, show_trace_feats[:,:17],
                                                                           box_feats,
                                                                           show_labels[..., :-1].squeeze(1),
                                                                           att_masks, show_masks.squeeze(1)[:,:17],
                                                                           task='both')
                # ### debug try using trace to give trace output
                # print(show_trace_feats.shape, show_labels.shape, show_masks.shape)
                # trace_output = model(fc_feats, att_feats, show_trace_feats[:, :17], box_feats,
                #                       show_labels[..., :-1].squeeze(1),
                #                       att_masks, show_masks.squeeze(1), task='trace')
            else:
                if test_baseline is True and task == 'caption':
                    seq, seq_logprobs, word_box_attn = model(fc_feats, att_feats, show_trace_feats[:att_feats.shape[0]], box_feats,
                                              att_masks, show_trace_masks[:att_feats.shape[0]],
                                              show_gate_labels[:att_feats.shape[0]], task, opt=tmp_eval_kwargs,
                                              mode='sample')
                else:
                    seq, seq_logprobs = model(fc_feats, att_feats, show_trace_feats[:att_feats.shape[0]], box_feats,
                                          att_masks, show_trace_masks[:att_feats.shape[0]], show_gate_labels[:att_feats.shape[0]], task, opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
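            # per-sequence entropy and "perplexity" (here: the mean negative log-prob of the
            # sampled tokens), both normalized by the number of non-padding tokens + 1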
            entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
            perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
            ### log which caption has no bounding box
            ids_no_box = (show_trace_feats[:, 0, 4] == 1).float()


        # only score word positions that have a ground-truth grounding box and a non-pad label
        if test_grounding_quality:
            batch_grounding_loss = []
            if test_baseline:
                word_box_attn = torch.argmax(word_box_attn, dim=-1)
            # match the generated word with the show-caption
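            # (the first and last label columns are assumed to be the BOS/EOS slots, so after
            #  this slicing, token position k lines up with trace box k)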
            show_labels = show_labels[:, :, 1:-1]
            for i in range(seq.shape[0]):
                for k in range(show_labels.shape[2]):
                    # (a stricter variant also required the generated word to match the
                    #  show-caption key word: seq[i, k] == show_labels[i, 0, k])
                    if show_trace_feats[i, k, 4] != 1 and show_labels[i, 0, k] != 0:
                        gt_box = show_trace_feats[i, k]  # the ground-truth grounding box
                        if test_baseline:
                            pred_box_idx = word_box_attn[i, k].long()
                            pred_box = box_feats[i, pred_box_idx]  # the predicted box
                        else:
                            pred_box = trace_output[i, k]
                        tmp_loss = torch.mean(torch.abs(gt_box[:4] - pred_box[:4]))
                        batch_grounding_loss.append(tmp_loss.item())

            # guard against batches with no valid grounded positions (np.mean of an empty list is NaN)
            if len(batch_grounding_loss) > 0:
                grounding_quality_loss.append(np.mean(np.array(batch_grounding_loss)))
            print('Visual grounding quality running ave: ', np.mean(np.array(grounding_quality_loss)))
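            # the batch was tiled 5x above; keep only the first sampled caption per original
            # image for the language-metric evaluation below (sequence length is hard-coded to 20)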
            seq = seq.reshape([-1, 5, 20])[:,0,:]


        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)
        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}

            # entry to evaluate show-control-tell: separate the 5 predictions per image
            # if ids_no_box[k]==1:
            #     continue
            # entry = {'image_id': data['infos'][k//5]['id'] + 1000000 * (k%5), 'caption': sent, 'perplexity': perplexity[k].item(),
            #          'entropy': entropy[k].item()}

            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %(entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_split_n(model, n_predictions, [fc_feats, att_feats, trace_feats, box_feats, att_masks, trace_masks, data], eval_kwargs)
        
        # ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        else:
            num_images = ix1
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %(n, ix1, loss))

        if num_images >= 0 and n >= num_images:
            break

    print('Total visual grounding quality loss:', np.mean(np.array(grounding_quality_loss)))

    lang_stats = None
    if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
        n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
    if not os.path.isdir('eval_results'):
        os.mkdir('eval_results')
    torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)

    # Switch back to training mode
    model.train()
    return loss_sum/loss_evals, predictions, lang_stats
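
A minimal calling sketch (build_model, build_criterion, build_loader, and opt are placeholders for whatever the surrounding training/evaluation script provides; only eval_split comes from this file):

    from captioning.utils.eval_utils_for_coco_caption import eval_split

    model = build_model(opt).to('cuda')   # hypothetical: a captioning model exposing .vocab and a sample mode
    crit = build_criterion(opt)           # hypothetical: the language-model criterion used during training
    loader = build_loader(opt)            # hypothetical: must implement get_batch(split) and reset_iterator(split)

    eval_kwargs = {'split': 'val', 'num_images': 5000, 'language_eval': 1, 'beam_size': 1,
                   'id': 'demo', 'dataset': 'coco', 'device': 'cuda'}
    val_loss, predictions, lang_stats = eval_split(model, crit, loader,
                                                   task='both', eval_kwargs=eval_kwargs)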