in scripts/eval_grd_anet_entities.py [0:0]
def precision_recall_util(self, mode='all'):
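    """Compute grounded precision and recall for every object class in the split.

    A predicted box counts as correct when it overlaps a same-class ground-truth box
    in the same frame with IoU above self.iou_thresh. Returns per-class and
    per-sentence precision/recall hit lists plus the set of classes seen in the split.
    """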
    ref = self.ref
    pred = self.pred
    print('Number of videos in the reference: {}, number of videos in the submission: {}'.format(len(ref), len(pred)))

    # Stanford CoreNLP is only needed for lemmatization, so that predicted words and
    # sentence tokens compare equal regardless of inflection (e.g., 'dogs' vs. 'dog')
    nlp = StanfordCoreNLP('tools/stanford-corenlp-full-2018-02-27')
    props = {'annotators': 'lemma', 'pipelineLanguage': 'en', 'outputFormat': 'json'}

    vocab_in_split = set()

    # precision
    prec = defaultdict(list)
    prec_per_sent = defaultdict(list)
    for vid, anns in tqdm(ref.items()):
        for seg, ann in anns['segments'].items():
            if len(ann['frame_ind']) == 0 or vid not in pred or seg not in pred[vid]:
                continue  # do not penalize precision if the sentence is not annotated or was not predicted
            prec_per_sent_tmp = []  # for each sentence
            ref_bbox_all = torch.cat((torch.Tensor(ann['process_bnd_box']),
                                      torch.Tensor(ann['frame_ind']).unsqueeze(-1)), dim=1)  # 5-D coordinates
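            # map each annotated object class in this sentence to the word indices where it occurs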
            idx_in_sent = {}
            for box_idx, cls_lst in enumerate(ann['process_clss']):
                vocab_in_split.update(set(cls_lst))
                for cls_idx, cls in enumerate(cls_lst):
                    idx_in_sent[cls] = idx_in_sent.get(cls, []) + [ann['process_idx'][box_idx][cls_idx]]

            sent_idx = set(itertools.chain.from_iterable(ann['process_idx']))  # indices of gt object words
            # lemmas of sentence tokens that are NOT annotated object words; predictions matching
            # these are skipped later rather than counted as hallucinations (missed annotations)
            exclude_obj = {json.loads(nlp.annotate(token, properties=props))['sentences'][0]['tokens'][0]['lemma']: 1
                           for token_idx, token in enumerate(ann['tokens'])
                           if (token_idx not in sent_idx and token != '')}
            for pred_idx, class_name in enumerate(pred[vid][seg]['clss']):
                if class_name in idx_in_sent:
                    gt_idx = min(idx_in_sent[class_name])  # always consider the first match...
                    sel_idx = [idx for idx, i in enumerate(ann['process_idx']) if gt_idx in i]
                    ref_bbox = ref_bbox_all[sel_idx]  # select matched boxes
                    assert ref_bbox.size(0) > 0
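                    # predicted boxes cover the 10 uniformly sampled frames of the segment;
                    # append the frame index (0-9) as the 5th coordinate, matching ref_bbox_all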
                    pred_bbox = torch.cat((torch.Tensor(pred[vid][seg]['bbox_for_all_frames'][pred_idx])[:, :4],
                                           torch.Tensor(range(10)).unsqueeze(-1)), dim=1)
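                    # frame mask restricts the IoU computation to predicted/GT box pairs from the same frame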
                    frm_mask = torch.from_numpy(get_frm_mask(pred_bbox[:, 4].numpy(),
                                                             ref_bbox[:, 4].numpy()).astype('uint8'))
                    overlap = bbox_overlaps_batch(pred_bbox[:, :5].unsqueeze(0),
                                                  ref_bbox[:, :5].unsqueeze(0), frm_mask.unsqueeze(0))
                    prec[class_name].append(1 if torch.max(overlap) > self.iou_thresh else 0)
                    prec_per_sent_tmp.append(1 if torch.max(overlap) > self.iou_thresh else 0)
                elif json.loads(nlp.annotate(class_name, properties=props))['sentences'][0]['tokens'][0]['lemma'] in exclude_obj:
                    pass  # do not penalize: the word appears in the sentence but was not annotated with a box
                else:
                    if mode == 'all':
                        prec[class_name].append(0)  # hallucinated object
                        prec_per_sent_tmp.append(0)

            prec_per_sent[vid + seg] = prec_per_sent_tmp

    nlp.close()
    # recall: for every annotated GT object word, check whether a same-class predicted box grounds it
    recall = defaultdict(list)
    recall_per_sent = defaultdict(list)
    for vid, anns in ref.items():
        for seg, ann in anns['segments'].items():
            if len(ann['frame_ind']) == 0:
                continue  # no box annotation available for this sentence
            recall_per_sent_tmp = []  # for each sentence
            ref_bbox_all = torch.cat((torch.Tensor(ann['process_bnd_box']),
                                      torch.Tensor(ann['frame_ind']).unsqueeze(-1)), dim=1)  # 5-D coordinates
            sent_idx = set(itertools.chain.from_iterable(ann['process_idx']))  # indices of gt object words
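            # every annotated GT object word must be grounded; any the predictions miss counts against recall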
            for gt_idx in sent_idx:
                sel_idx = [idx for idx, i in enumerate(ann['process_idx']) if gt_idx in i]
                ref_bbox = ref_bbox_all[sel_idx]  # select matched boxes
                # Note: although discouraged, a single word may be annotated with multiple boxes/frames
                assert ref_bbox.size(0) > 0
                class_name = ann['process_clss'][sel_idx[0]][ann['process_idx'][sel_idx[0]].index(gt_idx)]
                if vid not in pred:
                    recall[class_name].append(0)  # video not grounded
                    recall_per_sent_tmp.append(0)
                elif seg not in pred[vid]:
                    recall[class_name].append(0)  # segment not grounded
                    recall_per_sent_tmp.append(0)
                elif class_name in pred[vid][seg]['clss']:
                    pred_idx = pred[vid][seg]['clss'].index(class_name)  # always consider the first match...
                    pred_bbox = torch.cat((torch.Tensor(pred[vid][seg]['bbox_for_all_frames'][pred_idx])[:, :4],
                                           torch.Tensor(range(10)).unsqueeze(-1)), dim=1)
                    frm_mask = torch.from_numpy(get_frm_mask(pred_bbox[:, 4].numpy(),
                                                             ref_bbox[:, 4].numpy()).astype('uint8'))
                    overlap = bbox_overlaps_batch(pred_bbox[:, :5].unsqueeze(0),
                                                  ref_bbox[:, :5].unsqueeze(0), frm_mask.unsqueeze(0))
                    recall[class_name].append(1 if torch.max(overlap) > self.iou_thresh else 0)
                    recall_per_sent_tmp.append(1 if torch.max(overlap) > self.iou_thresh else 0)
                else:
                    if mode == 'all':
                        recall[class_name].append(0)  # object not grounded
                        recall_per_sent_tmp.append(0)

            recall_per_sent[vid + seg] = recall_per_sent_tmp
    return prec, recall, prec_per_sent, recall_per_sent, vocab_in_split
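
# --- Hedged usage sketch (illustrative, not part of the original script) ---
# One plausible way to aggregate the per-class hit lists returned by
# precision_recall_util() into macro-averaged scores. The `evaluator` name below is
# an assumed instance of the evaluation class in this file; the averaging scheme is
# an assumption for illustration, not necessarily the one the script itself uses.
#
#     prec, recall, prec_per_sent, recall_per_sent, vocab_in_split = \
#         evaluator.precision_recall_util(mode='all')
#     prec_macro = sum(sum(h) / float(len(h)) for h in prec.values() if h) / max(len(prec), 1)
#     recall_macro = sum(sum(h) / float(len(h)) for h in recall.values() if h) / max(len(recall), 1)
#     print('macro precision: {:.4f}, macro recall: {:.4f}'.format(prec_macro, recall_macro))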