def eval_split()

in captioning/utils/eval_utils_joint.py
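
Evaluates the model on one split of the loader. For each batch it optionally computes the captioning loss against ground-truth labels, samples captions (and, for task='both', a predicted trace), records per-caption entropy and perplexity, and for the joint task measures an LBM distance between predicted and ground-truth traces via local_OT over trace segments. Predictions are saved under eval_results/ and, when language_eval is set, scored with the standard language metrics. The function returns the averaged loss, the list of predictions, and the language-evaluation stats.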


def eval_split(model, crit, loader, task='caption', eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
    device = eval_kwargs.get('device', 'cuda')

    # Supported task modes: caption-only, trace-only, or joint caption + trace generation
    assert task in ['caption', 'trace', 'both']

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    n_predictions = [] # when sample_n > 1
    trace_cost = []
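    # Iterate over the split in batches until the requested number of images has been evaluated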
    while True:
        data = loader.get_batch(split)
        n = n + len(data['infos'])

        tmp = [data['fc_feats'], data['att_feats'], data['trace_feats'], data['box_feats'], data['labels'], data['masks'], data['att_masks'], data['trace_masks']]
        tmp = [_.to(device) if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, trace_feats, box_feats, labels, masks, att_masks, trace_masks = tmp
        if labels is not None and verbose_loss:
            # forward the model to get loss
            with torch.no_grad():
                if task == 'caption':
                    loss = crit(model(fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1], att_masks, trace_masks, task=task), labels[..., 1:], masks[..., 1:]).item()
                elif task == 'both':
                    loss = crit(
                        model(fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1], att_masks, trace_masks,
                              task=task)[0], labels[..., 1:], masks[..., 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # forward the model to also get generated samples for each image
        with torch.no_grad():
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            if task == 'both':
                seq, seq_logprobs, trace_predicted = model(fc_feats, att_feats, trace_feats, box_feats, att_masks, trace_masks,
                                          task=task, opt=tmp_eval_kwargs, mode='sample')
            else:
                try:
                    seq, seq_logprobs = model(fc_feats, att_feats, trace_feats, box_feats, att_masks, trace_masks, task=task, opt=tmp_eval_kwargs, mode='sample')
                except Exception:
                    print('evaluation encountered an error; skipping this batch')
                    continue
            seq = seq.data
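            # Per-caption uncertainty statistics: token entropy and mean negative
            # log-probability (stored as 'perplexity'), averaged over generated tokens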
            entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
            perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)

        if task == 'both':
            ### compute the loss for trace
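            # For each image: truncate the predicted trace to the generated caption length,
            # drop ground-truth placeholder boxes ([0, 0, 1, 1, 1]), and score the two traces
            # with a segment-wise local optimal transport (LBM) cost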
            for k in range(trace_predicted.shape[0]):
                tmp_gt_length = trace_masks[k].sum().long()
                tmp_gt_trace = trace_feats[k, :tmp_gt_length]
                tmp_pred_length = (seq[k]>0).sum().long()
                tmp_pred_trace = trace_predicted[k, :tmp_pred_length]

                # choose only boxes not [0,0,1,1,1] in the ground truth
                nonzero_idx = torch.nonzero(tmp_gt_trace[:, 4] != 1).squeeze()
                tmp_gt_trace = tmp_gt_trace[nonzero_idx]
                if len(tmp_gt_trace.shape) < 2:  # if there is only one chosen box in this trace
                    tmp_gt_trace = tmp_gt_trace.unsqueeze(0)
                tmp_gt_trace = tmp_gt_trace.unsqueeze(0)
                tmp_pred_trace = tmp_pred_trace.unsqueeze(0)

                if tmp_pred_trace.shape[1] <= tmp_gt_trace.shape[1]:
                    tmp_trace1 = tmp_pred_trace
                    tmp_trace2 = tmp_gt_trace
                else:
                    tmp_trace1 = tmp_gt_trace
                    tmp_trace2 = tmp_pred_trace
                # processing in terms of segments of length 20
                seg_loss_list = []
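                # Align the shorter trace to the longer one in 20-step segments: each segment
                # is matched via local_OT against a proportionally scaled window of the longer
                # trace, and the per-segment costs are averaged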
                for seg_idx in range(np.ceil(tmp_trace1.shape[1] / 20).astype(int)):
                    tmp_const = 20. * tmp_trace2.shape[1] / tmp_trace1.shape[1]
                    seg_tmp_trace1 = tmp_trace1[:, seg_idx * 20:(seg_idx + 1) * 20, :4]
                    seg_tmp_trace2 = tmp_trace2[:, np.floor(seg_idx * tmp_const).astype(int): np.ceil(
                        (seg_idx + 1) * tmp_const).astype(int), :4]
                    D = torch.abs(seg_tmp_trace1.unsqueeze(2) - seg_tmp_trace2.unsqueeze(1)).mean(dim=-1)
                    seg_tmp_T = local_OT(D, window = 0)
                    seg_tmp_cost = (seg_tmp_T * D).sum() / seg_tmp_trace1.shape[1]
                    if not torch.isnan(seg_tmp_cost):
                        seg_loss_list.append(seg_tmp_cost.item())
                tmp_cost = np.mean(np.array(seg_loss_list))
                if not np.isnan(tmp_cost):
                    trace_cost.append(tmp_cost)
                print('trace LBM distance:', tmp_cost)

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)
        if task == 'both':
            print('both trace running average LBM loss:', np.mean(np.array(trace_cost)))

        # ### save for visualization  # for visualization of trace_generation
        # for i in range(len(sents)):
        #     vis_img_id = data['infos'][i]['id']
        #     with open('./vis/both_generation_supplement/pred_caption/pred_caption_' + str(vis_img_id)+'.txt', 'w') as f:
        #         f.write(sents[i])
        #     np.save('./vis/both_generation_supplement/pred_trace/pred_trace_' + str(vis_img_id),
        #             trace_predicted[i, :, :4].detach().cpu().numpy())
        #     print(vis_img_id, trace_feats.shape)
        #     with open('./vis/both_generation_supplement/info.txt', 'a') as f:
        #         f.write('img_id:%d\n' %vis_img_id)
        #         f.close()
        # ############################

        # ### save for visualization  # for visualization of caption_generation
        # for i in range(len(sents)):
        #     vis_img_id = data['infos'][i]['id']
        #     tmp_dir = './vis/caption_generation_' + eval_kwargs['dataset_choice']
        #     if not os.path.exists(tmp_dir):
        #         os.makedirs(tmp_dir)
        #         os.makedirs(tmp_dir + '/pred_caption')
        #         os.makedirs(tmp_dir + '/gt_trace')
        #     with open('./vis/caption_generation_'+ eval_kwargs['dataset_choice'] +'/pred_caption/pred_caption_' + str(vis_img_id) + '.txt',
        #               'w') as f:
        #         f.write(sents[i])
        #     np.save('./vis/caption_generation_'+ eval_kwargs['dataset_choice'] +'/gt_trace/gt_trace_' + str(vis_img_id),
        #             trace_feats[i, :, :4].detach().cpu().numpy())
        #     print(vis_img_id, trace_feats.shape)
        #     with open('./vis/caption_generation_'+ eval_kwargs['dataset_choice'] +'/info.txt', 'a') as f:
        #         f.write('img_id:%s\n' % str(vis_img_id))
        #         f.close()
        # ############################


        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %(entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_split_n(model, n_predictions, [fc_feats, att_feats, trace_feats, box_feats, att_masks, trace_masks, data], eval_kwargs)

        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        else:
            num_images = ix1

        # The last batch may overshoot the requested image budget; drop the extra predictions
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %(n, ix1, loss))

        if num_images >= 0 and n >= num_images:
            break

    if task == 'both':
        print('both trace total LBM loss:', np.mean(np.array(trace_cost)))

    lang_stats = None
    if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
        n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
    if not os.path.isdir('eval_results'):
        os.mkdir('eval_results')
    torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
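    # Optionally score the predictions with the standard language-evaluation metrics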
    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)

    # Switch back to training mode
    model.train()
    return loss_sum/loss_evals, predictions, lang_stats
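
Below is a minimal, hypothetical invocation sketch (not part of the original file). It assumes model, crit, and loader are wired up the way the accompanying training script uses them: the loader exposes reset_iterator/get_batch, and the model supports the forward signatures called above. The option values are illustrative only.

eval_kwargs = {
    'split': 'val',           # which split the loader should iterate
    'val_images_use': 5000,   # stop after this many images (-1 = all)
    'beam_size': 3,
    'language_eval': 1,       # score predictions with language_eval() at the end
    'id': 'my_experiment',    # used to name the saved prediction file in eval_results/
    'device': 'cuda',
}
val_loss, predictions, lang_stats = eval_split(
    model, crit, loader, task='caption', eval_kwargs=eval_kwargs)
print('validation loss:', val_loss)
print('language metrics:', lang_stats)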