in captioning/utils/eval_utils.py [0:0]
def eval_split(model, crit, loader, eval_kwargs=None):
    """Evaluate ``model`` on one split of ``loader`` and collect captions.

    For every batch the function optionally computes the criterion loss,
    samples one caption per image, records per-caption perplexity/entropy,
    and finally saves the predictions under ``eval_results/``.

    Args:
        model: Callable model. It is invoked once with
            ``(..., labels[..., :-1], att_masks, trace_masks, 'caption')`` to
            obtain the criterion input and once with ``mode='sample'`` to
            obtain ``(seq, seq_logprobs)``. Must expose ``eval()``,
            ``train()``, ``vocab`` and, when beams are printed, ``done_beams``.
        crit: Loss callable, called as
            ``crit(model_output, labels[..., 1:], masks[..., 1:])``; its
            result must support ``.item()``.
        loader: Data loader exposing ``reset_iterator(split)`` and
            ``get_batch(split)``. Each batch is a dict with the keys
            ``infos``, ``fc_feats``, ``att_feats``, ``trace_feats``,
            ``box_feats``, ``labels``, ``masks``, ``att_masks``,
            ``trace_masks`` and ``bounds`` (with ``it_max``).
        eval_kwargs: Option dict (``None`` is treated as empty). Keys read
            here, with their defaults: ``verbose`` (True), ``verbose_beam``
            (0), ``verbose_loss`` (1), ``num_images`` (falls back to
            ``val_images_use``, then -1 meaning "use ``it_max``"), ``split``
            ('val'), ``language_eval`` (0), ``dataset`` ('coco'),
            ``beam_size`` (1), ``sample_n`` (1), ``remove_bad_endings`` (0),
            ``device`` ('cuda'), ``dump_path`` (0), ``dump_images`` (0).
            ``id`` is required (used in the saved file name) and
            ``image_root`` is required when ``dump_images == 1``.

    Returns:
        Tuple ``(mean_loss, predictions, lang_stats)`` where ``mean_loss`` is
        ``loss_sum / loss_evals``, ``predictions`` is a list of dicts with
        ``image_id``, ``caption``, ``perplexity``, ``entropy`` (plus
        ``file_name`` when ``dump_path == 1``), and ``lang_stats`` is the
        result of ``language_eval`` or ``None`` when ``language_eval != 1``.

    Side effects:
        Sets ``os.environ["REMOVE_BAD_ENDINGS"]``, writes
        ``eval_results/.saved_pred_<id>_<split>.pth``, optionally copies
        images into ``vis/imgs/``, prints progress, and leaves the model in
        training mode on return.
    """
    # Function-scope import so this block does not depend on the module's
    # import section providing shutil.
    import shutil

    if eval_kwargs is None:
        eval_kwargs = {}

    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    # Global configuration passed through the environment so other code can
    # read it without extra plumbing.
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)
    device = eval_kwargs.get('device', 'cuda')
    # These flags do not change between batches; read them once.
    dump_path = eval_kwargs.get('dump_path', 0) == 1
    dump_images = eval_kwargs.get('dump_images', 0) == 1

    # Make sure we are in evaluation mode.
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    # Starts at a tiny epsilon so the final division is safe even when no
    # loss was ever evaluated.
    loss_evals = 1e-8
    predictions = []
    n_predictions = []  # filled by eval_split_n when sample_n > 1
    while True:
        data = loader.get_batch(split)
        n = n + len(data['infos'])

        batch = [data['fc_feats'], data['att_feats'], data['trace_feats'],
                 data['box_feats'], data['labels'], data['masks'],
                 data['att_masks'], data['trace_masks']]
        # Entries may be None (e.g. missing labels); only move real tensors.
        batch = [t.to(device) if t is not None else t for t in batch]
        (fc_feats, att_feats, trace_feats, box_feats,
         labels, masks, att_masks, trace_masks) = batch

        if labels is not None and verbose_loss:
            # Forward the model to get the loss: inputs drop the last label
            # position, targets and masks drop the first.
            with torch.no_grad():
                loss = crit(
                    model(fc_feats, att_feats, trace_feats, box_feats,
                          labels[..., :-1], att_masks, trace_masks, 'caption'),
                    labels[..., 1:], masks[..., 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # Forward the model again to get one generated sample per image.
        with torch.no_grad():
            # Copy so the caller's dict is untouched; sample_n > 1 is handled
            # separately by eval_split_n below.
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            seq, seq_logprobs = model(fc_feats, att_feats, trace_feats, box_feats,
                                      att_masks, trace_masks,
                                      opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
            # Both statistics are normalised by (number of positive tokens + 1).
            # NOTE(review): softmax is applied to values named "logprobs";
            # presumably intentional re-normalisation -- confirm.
            denom = (seq > 0).to(seq_logprobs).sum(1) + 1
            entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / denom
            perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / denom

        # Print beam search results.
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, beam['seq'].unsqueeze(0))[0]
                                 for beam in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        for k, sent in enumerate(sents):
            info = data['infos'][k]
            entry = {'image_id': info['id'], 'caption': sent,
                     'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if dump_path:
                entry['file_name'] = info['file_path']
            predictions.append(entry)
            if dump_images:
                # Dump the raw image to the vis/ folder. Copy without a shell
                # so unusual characters in the path cannot break or inject
                # into a command line.
                src = os.path.join(eval_kwargs['image_root'], info['file_path'])
                dst = 'vis/imgs/img' + str(len(predictions)) + '.jpg'
                # Same log line format as the former shell command.
                print('cp "' + src + '" ' + dst)
                try:
                    os.makedirs('vis/imgs', exist_ok=True)
                    shutil.copy(src, dst)
                except OSError as err:
                    # Best effort: a failed dump must not abort evaluation.
                    print('failed to dump image %s: %s' % (src, err))

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_split_n(model, n_predictions,
                         [fc_feats, att_feats, trace_feats, box_feats,
                          att_masks, trace_masks, data],
                         eval_kwargs)

        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        else:
            # No explicit limit requested: evaluate up to it_max.
            num_images = ix1
        # The last batch may overshoot the limit; drop the surplus entries.
        for _ in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation preformance... %d/%d (%f)' % (n, ix1, loss))

        if num_images >= 0 and n >= num_images:
            break

    lang_stats = None
    if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
        n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs('eval_results', exist_ok=True)
    torch.save((predictions, n_predictions),
               os.path.join('eval_results/', '.saved_pred_' + eval_kwargs['id'] + '_' + split + '.pth'))
    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)

    # Switch back to training mode.
    model.train()
    return loss_sum / loss_evals, predictions, lang_stats