def wer_and_cer()

in metrics/xtreme_s/xtreme_s.py [0:0]


def wer_and_cer(preds, labels, concatenate_texts, config_name):
    """Compute word error rate (WER) and character error rate (CER) with jiwer.

    Works with both the legacy jiwer API (``jiwer.compute_measures``) and the
    newer one (``jiwer.process_words`` / ``jiwer.process_characters``); the
    installed version is detected at call time.

    Args:
        preds: Sequence of predicted (hypothesis) transcriptions.
        labels: Sequence of reference transcriptions, aligned with ``preds``.
        concatenate_texts: If True, the full lists are handed to jiwer in a
            single call and its ``wer``/``cer`` is returned as-is. If False,
            each (prediction, reference) pair is scored separately and the
            edit counts are pooled over all pairs before dividing.
        config_name: Name of the calling config; only used in the error message.

    Returns:
        A dict ``{"wer": ..., "cer": ...}``.

    Raises:
        ValueError: If jiwer is not installed, or if ``concatenate_texts`` is
            False and ``preds`` and ``labels`` have different lengths.
    """
    try:
        import jiwer
    except ImportError as err:
        raise ValueError(
            f"jiwer has to be installed in order to apply the wer metric for {config_name}. "
            "You can install it via `pip install jiwer`."
        ) from err

    # Newer jiwer releases dropped `compute_measures` in favour of
    # `process_words` / `process_characters`; use whichever API is present.
    legacy_api = hasattr(jiwer, "compute_measures")

    if concatenate_texts:
        if legacy_api:
            wer = jiwer.compute_measures(labels, preds)["wer"]
            # The legacy API has no CER helper: CER is the "wer" computed over
            # the tokens produced by `cer_transform`.
            cer = jiwer.compute_measures(
                labels, preds, truth_transform=cer_transform, hypothesis_transform=cer_transform
            )["wer"]
        else:
            wer = jiwer.process_words(labels, preds).wer
            cer = jiwer.process_characters(labels, preds).cer
        return {"wer": wer, "cer": cer}

    # `zip` below would silently drop unmatched items and skew the score.
    if len(preds) != len(labels):
        raise ValueError(
            f"preds and labels must have the same length, got {len(preds)} and {len(labels)}."
        )

    def edit_counts(reference, prediction, score_type):
        """Return (substitutions, deletions, insertions, hits) for one pair."""
        if legacy_api:
            if score_type == "wer":
                measures = jiwer.compute_measures(reference, prediction)
            else:
                measures = jiwer.compute_measures(
                    reference, prediction, truth_transform=cer_transform, hypothesis_transform=cer_transform
                )
            return (
                measures["substitutions"],
                measures["deletions"],
                measures["insertions"],
                measures["hits"],
            )
        if score_type == "wer":
            output = jiwer.process_words(reference, prediction)
        else:
            output = jiwer.process_characters(reference, prediction)
        return output.substitutions, output.deletions, output.insertions, output.hits

    def compute_score(score_type):
        """Pool edit counts over all pairs and return errors / reference length."""
        incorrect = 0
        total = 0
        for prediction, reference in zip(preds, labels):
            substitutions, deletions, insertions, hits = edit_counts(reference, prediction, score_type)
            incorrect += substitutions + deletions + insertions
            # substitutions + deletions + hits is the number of reference tokens.
            total += substitutions + deletions + hits
        return incorrect / total

    return {"wer": compute_score("wer"), "cer": compute_score("cer")}