in scripts/adapet/ADAPET/src/eval/Scorer.py [0:0]
def add_batch(self, list_idx, list_pred_lbl, list_true_lbl, lbl_logits, list_candidates=None):
    '''
    Record one batch of predictions, grouped by example index.

    The raw ``lbl_logits`` object is appended to ``self.list_logits``, and one
    tuple per row is appended to ``self.dict_idx2logits_lbl[idx]``. The tuple is
    ``(pred_lbl, true_lbl, logit)``, or ``(pred_lbl, true_lbl, logit, cnd)``
    when ``list_candidates`` is given.

    :param list_idx: iterable of hashable keys, one per row of the batch
    :param list_pred_lbl: predicted labels; a torch tensor is converted to a
        numpy array first, any other iterable is used as-is
    :param list_true_lbl: true labels; handled the same way as list_pred_lbl
    :param lbl_logits: per-row logits; must provide ``.tolist()``
        (presumably a torch tensor -- confirm against the caller)
    :param list_candidates: optional iterable with one entry per row
    :return: None
    '''
    # Keep the original logits object; only the per-row records use plain lists.
    self.list_logits.append(lbl_logits)
    logit_rows = lbl_logits.tolist()

    # Tensors are moved to host memory so iteration yields numpy scalars
    # instead of 0-dim tensors. Both label arguments get the same treatment.
    if torch.is_tensor(list_pred_lbl):
        list_pred_lbl = list_pred_lbl.cpu().detach().numpy()
    if torch.is_tensor(list_true_lbl):
        list_true_lbl = list_true_lbl.cpu().detach().numpy()

    columns = [list_idx, list_pred_lbl, list_true_lbl, logit_rows]
    if list_candidates is not None:
        columns.append(list_candidates)

    # zip stops at the shortest column, as the original per-branch loops did.
    for idx, *record in zip(*columns):
        self.dict_idx2logits_lbl.setdefault(idx, []).append(tuple(record))