in notebooks/src/code/data/ner.py [0:0]
def compute_metrics(p: Any) -> Dict[str, Real]:
    """Compute token-level accuracy metrics for one evaluation pass.

    This function does not define ``np``, ``logger``, ``pad_token_label_id`` or
    ``other_class_label`` itself; they must be available from the enclosing scope.

    Args:
        p: A pair ``(probs_raw, labels_raw)``. ``probs_raw`` holds per-class scores whose
            class axis is axis 2, presumably shaped (examples, tokens, classes) - confirm
            with the caller. ``labels_raw`` holds the true class id of each token, with
            ``pad_token_label_id`` marking positions that must be ignored.

    Returns:
        A dict with:

        - ``"n_examples"``: ``probs_raw.shape[0]``.
        - ``"acc"``: mean over examples of per-example token accuracy, ignoring positions
          labelled ``pad_token_label_id``. Examples with no scorable tokens are skipped
          rather than producing NaN.
        - ``"focus_acc"``: like ``"acc"``, but also ignoring tokens where both the label and
          the prediction equal ``other_class_label``. Only examples with at least one such
          "focus" token contribute to the average.
        - ``"focus_else_acc_minus_one"``: ``focus_acc`` if positive, else ``acc - 1``.

        An average over an empty set of examples is reported as 0.0.
    """

    def _mean_ratio(numerators: np.ndarray, denominators: np.ndarray) -> float:
        # Mean of per-example ratios over examples with a non-zero denominator. Skipping the
        # empty ones avoids 0/0 = NaN. Keeping numerator and denominator subsets aligned also
        # stops empty examples from silently counting as 0% accuracy.
        valid = denominators > 0
        if not valid.any():
            return 0.0
        return float(np.mean(numerators[valid] / denominators[valid]))

    probs_raw, labels_raw = p
    predicted_class_ids_raw = np.argmax(probs_raw, axis=2)

    # Override predictions at ignored positions with the ignore value:
    non_pad_labels = labels_raw != pad_token_label_id
    predicted_class_ids_raw = np.where(
        non_pad_labels,
        predicted_class_ids_raw,
        pad_token_label_id,
    )
    correct = labels_raw == predicted_class_ids_raw

    # Accuracy ignoring PAD, CLS and SEP tokens (positions labelled pad_token_label_id):
    n_tokens_by_example = non_pad_labels.sum(axis=1)
    n_correct_by_example = np.logical_and(correct, non_pad_labels).sum(axis=1)

    # Accuracy ignoring PAD, CLS, SEP tokens *and* tokens where both pred and actual classes
    # are 'other':
    focus_labels = np.logical_and(
        non_pad_labels,
        np.logical_or(
            labels_raw != other_class_label,
            predicted_class_ids_raw != other_class_label,
        ),
    )
    n_focus_tokens_by_example = focus_labels.sum(axis=1)
    n_focus_correct_by_example = np.logical_and(correct, focus_labels).sum(axis=1)

    # Share of each predicted class among the non-ignored tokens. Count only those positions,
    # so a pad id that happens to be a real class index cannot hide real predictions:
    unique_labels, unique_counts = np.unique(
        predicted_class_ids_raw[non_pad_labels], return_counts=True
    )
    n_tokens_total = int(n_tokens_by_example.sum())
    logger.info(
        "Evaluation class prediction ratios: %s",
        {
            int(label): float(count / n_tokens_total)
            for label, count in zip(unique_labels, unique_counts)
        },
    )

    acc = _mean_ratio(n_correct_by_example, n_tokens_by_example)
    focus_acc = _mean_ratio(n_focus_correct_by_example, n_focus_tokens_by_example)
    return {
        "n_examples": probs_raw.shape[0],
        "acc": acc,
        "focus_acc": focus_acc,
        # By nature of the metric, focus_acc can sometimes take a few epochs to move away from
        # 0.0. Since acc and focus_acc are both 0-1, we can define this metric to show early
        # improvement (thus prevent early stopping) but still target focus_acc later:
        "focus_else_acc_minus_one": focus_acc if focus_acc > 0 else acc - 1,
    }