in expanded_checklist/checklist/core_record.py [0:0]
def accumulate_ner_probs(self):
    """
    Merge per-tag confidences into per-class confidences, in place.

    For every token, the probabilities of all sequence tags that
    ``get_class_from_seq_label`` maps to the same class are summed,
    e.g. the probabilities for B-ORG, I-ORG and L-ORG are accumulated
    into a single ORG entry.

    Side effects on the record:
      * ``self.confs``: each token vector is replaced by a numpy array
        ordered like ``self.get_classes()``; numpy containers met on the
        way down are converted to (nested) lists first.
      * ``self.preds``: each token prediction is replaced by the class with
        the highest accumulated confidence.
      * ``self.label_vocab`` is set to ``self.get_classes()``.
      * ``self._simplify_seq_labels()`` is called last.

    IMPORTANT: this also changes predictions and labels (makes them less
    fine-grained, e.g. ORG instead of B-ORG). The labels are presumably
    handled by ``_simplify_seq_labels`` -- verify there.

    A warning is logged when merging changes the predicted class of at
    least one token.
    """
    new_softmax_vocab = self.get_classes()
    label_vocab = self.label_vocab

    def acc_probs_for_token(token_confs):
        # Sum the confidences of all tags sharing a class; the B-/I-/...
        # part of the tag is dropped by get_class_from_seq_label.
        tag2prob = defaultdict(int)
        for p, prob in enumerate(token_confs):
            tag2prob[get_class_from_seq_label(label_vocab[p])] += prob
        new_confs = np.array([tag2prob[x] for x in new_softmax_vocab])
        return new_confs, new_softmax_vocab[np.argmax(new_confs)]

    def merge_tokens(seq_confs, seq_preds):
        # Rewrite one token sequence in place. Returns the pair
        # (tokens whose predicted class changed, tokens processed).
        changed = 0
        for t, token_confs in enumerate(seq_confs):
            old_pred = seq_preds[t]
            new_confs, new_pred = acc_probs_for_token(token_confs)
            seq_confs[t] = new_confs
            seq_preds[t] = new_pred
            if new_pred != get_class_from_seq_label(old_pred):
                changed += 1
        return changed, len(seq_confs)

    # Changing confs can change predicted classes (e.g. entity types),
    # although it's quite unlikely -- count how often it happens.
    changed_pred_class = 0
    all_tokens = 0

    # This loop works for both grouped and ungrouped data.
    # At every level a numpy container is swapped for its list form so that
    # single token entries can be replaced by the merged vectors (which are
    # indexed by class rather than by tag). The loop variable is re-bound
    # to the converted list so reads and writes use the same object.
    for e, example_confs in enumerate(self.confs):
        if isinstance(example_confs, np.ndarray):
            example_confs = example_confs.tolist()
            self.confs[e] = example_confs

        for g, group_confs in enumerate(example_confs):
            if isinstance(group_confs, np.ndarray):
                group_confs = group_confs.tolist()
                example_confs[g] = group_confs

            # Nothing to merge; also keeps the [0] probe below from raising.
            if len(group_confs) == 0:
                continue

            if is_1d_list(group_confs[0]):
                # group_confs is a single token sequence
                changed, seen = merge_tokens(group_confs, self.preds[e][g])
                changed_pred_class += changed
                all_tokens += seen
            else:
                # for each group there are many versions of each example
                for v, version_confs in enumerate(group_confs):
                    if isinstance(version_confs, np.ndarray):
                        version_confs = version_confs.tolist()
                        group_confs[v] = version_confs
                    changed, seen = merge_tokens(
                        version_confs, self.preds[e][g][v])
                    changed_pred_class += changed
                    all_tokens += seen

    if changed_pred_class > 0:
        # Implicit concatenation keeps source indentation out of the message.
        logger.warning(
            f'Changed prediction for '
            f'{changed_pred_class}/{all_tokens} tokens after merging '
            f'confs for each entity type.')

    self.label_vocab = new_softmax_vocab
    self._simplify_seq_labels()