def compute_metrics()

in notebooks/src/code/data/ner.py [0:0]
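
Evaluation-metric callback for the token-classification (NER) model. Given raw class scores and label IDs, it returns per-example token accuracy over non-padding tokens (acc), a focus_acc that additionally ignores tokens where both the prediction and the label are the 'other' class, and a combined focus_else_acc_minus_one intended as a more stable metric to monitor during training.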


    def compute_metrics(p: Any) -> Dict[str, Real]:
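        """Compute evaluation metrics from a (scores, label IDs) prediction tuple.

        `p` unpacks into raw class scores of shape (n_examples, seq_len, n_classes) and integer
        label IDs of shape (n_examples, seq_len). `np`, `pad_token_label_id`, `other_class_label`
        and `logger` are defined in the surrounding code in ner.py.
        """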
        probs_raw, labels_raw = p
        predicted_class_ids_raw = np.argmax(probs_raw, axis=2)

        # Replace predictions at padding positions with the pad/ignore label ID:
        non_pad_labels = labels_raw != pad_token_label_id
        predicted_class_ids_raw = np.where(
            non_pad_labels,
            predicted_class_ids_raw,
            pad_token_label_id,
        )

        # Count how often each class was predicted (for the ratio logging below):
        unique_labels, unique_counts = np.unique(predicted_class_ids_raw, return_counts=True)

        # Accuracy ignoring PAD, CLS and SEP tokens:
        n_tokens_by_example = non_pad_labels.sum(axis=1)
        n_tokens_total = n_tokens_by_example.sum()
        n_correct_by_example = np.logical_and(
            labels_raw == predicted_class_ids_raw, non_pad_labels
        ).sum(axis=1)
        acc_by_example = np.true_divide(n_correct_by_example, n_tokens_by_example)

        # "Focus" accuracy: ignore PAD, CLS and SEP tokens *and* tokens where both the predicted
        # and actual class are 'other' (the OR below is the negation of "both equal
        # other_class_label"):
        focus_labels = np.logical_and(
            non_pad_labels,
            np.logical_or(
                labels_raw != other_class_label,
                predicted_class_ids_raw != other_class_label,
            ),
        )
        n_focus_tokens_by_example = focus_labels.sum(axis=1)
        n_focus_correct_by_example = np.logical_and(
            labels_raw == predicted_class_ids_raw,
            focus_labels,
        ).sum(axis=1)
        focus_acc_by_example = np.true_divide(
            n_focus_correct_by_example[n_focus_tokens_by_example != 0],
            n_focus_tokens_by_example[n_focus_tokens_by_example != 0],
        )
        logger.info(
            "Evaluation class prediction ratios: {}".format(
                {
                    unique_labels[ix]: unique_counts[ix] / n_tokens_total
                    for ix in range(len(unique_counts))
                    if unique_labels[ix] != pad_token_label_id
                }
            )
        )
        n_examples = probs_raw.shape[0]
        acc = acc_by_example.sum() / n_examples
        # Examples with no focus tokens are dropped from focus_acc_by_example but still count in
        # the denominator here, so they effectively contribute a focus accuracy of zero:
        focus_acc = focus_acc_by_example.sum() / n_examples
        return {
            "n_examples": n_examples,
            "acc": acc,
            "focus_acc": focus_acc,
            # By nature of the metric, focus_acc can take a few epochs to move away from 0.0.
            # Since acc and focus_acc both lie in [0, 1], this combined metric shows early
            # improvement (avoiding premature early stopping) while still targeting focus_acc
            # once it becomes non-zero:
            "focus_else_acc_minus_one": focus_acc if focus_acc > 0 else acc - 1,
        }
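
To make the two accuracy definitions concrete, below is a minimal, self-contained sketch of the same calculation on a toy batch. The specific values used (0 as the 'other' class ID, -100 as the padding label ID, and the fake scores) are illustrative assumptions only; in ner.py the real IDs come from the surrounding code.

    import numpy as np

    # Illustrative assumptions (not taken from ner.py):
    pad_token_label_id = -100
    other_class_label = 0

    # 2 examples x 4 tokens x 2 classes of fake scores, plus matching label IDs:
    probs_raw = np.array([
        [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]],
        [[0.1, 0.9], [0.6, 0.4], [0.8, 0.2], [0.3, 0.7]],
    ])
    labels_raw = np.array([
        [0, 1, 0, pad_token_label_id],  # last token is padding
        [1, 1, 0, pad_token_label_id],
    ])

    preds = np.argmax(probs_raw, axis=2)
    non_pad = labels_raw != pad_token_label_id
    preds = np.where(non_pad, preds, pad_token_label_id)

    # Plain accuracy over non-padding tokens:
    acc_by_example = ((preds == labels_raw) & non_pad).sum(axis=1) / non_pad.sum(axis=1)

    # "Focus" accuracy: also drop tokens where both prediction and label are 'other':
    focus = non_pad & ((labels_raw != other_class_label) | (preds != other_class_label))
    focus_correct = ((preds == labels_raw) & focus).sum(axis=1)
    focus_total = focus.sum(axis=1)
    focus_acc_by_example = focus_correct[focus_total != 0] / focus_total[focus_total != 0]

    print(acc_by_example)        # [1.0, 0.667]: all 3 real tokens right, then 2 of 3
    print(focus_acc_by_example)  # [1.0, 0.5]: correctly-predicted 'other' tokens no longer count

The second example shows why focus_acc is stricter: the token the model got right only because both the prediction and the label were 'other' is excluded from the focus count, so the per-example score drops from 0.667 to 0.5.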