normalizer/eval_utils.py (143 lines of code) (raw):

import os import glob import json import evaluate from collections import defaultdict def read_manifest(manifest_path: str): """ Reads a manifest file (jsonl format) and returns a list of dictionaries containing samples. """ data = [] with open(manifest_path, "r", encoding="utf-8") as f: for line in f: if len(line) > 0: datum = json.loads(line) data.append(datum) return data def write_manifest( references: list, transcriptions: list, model_id: str, dataset_path: str, dataset_name: str, split: str, audio_length: list = None, transcription_time: list = None, ): """ Writes a manifest file (jsonl format) and returns the path to the file. Args: references: Ground truth reference texts. transcriptions: Model predicted transcriptions. model_id: String identifier for the model. dataset_path: Path to the dataset. dataset_name: Name of the dataset. split: Dataset split name. audio_length: Length of each audio sample in seconds. transcription_time: Transcription time of each sample in seconds. Returns: Path to the manifest file. """ model_id = model_id.replace("/", "-") dataset_path = dataset_path.replace("/", "-") dataset_name = dataset_name.replace("/", "-") if len(references) != len(transcriptions): raise ValueError( f"The number of samples in `references` ({len(references)}) " f"must match `transcriptions` ({len(transcriptions)})." ) if audio_length is not None and len(audio_length) != len(references): raise ValueError( f"The number of samples in `audio_length` ({len(audio_length)}) " f"must match `references` ({len(references)})." ) if transcription_time is not None and len(transcription_time) != len(references): raise ValueError( f"The number of samples in `transcription_time` ({len(transcription_time)}) " f"must match `references` ({len(references)})." ) audio_length = ( audio_length if audio_length is not None else len(references) * [None] ) transcription_time = ( transcription_time if transcription_time is not None else len(references) * [None] ) basedir = "./results/" if not os.path.exists(basedir): os.makedirs(basedir) manifest_path = os.path.join( basedir, f"MODEL_{model_id}_DATASET_{dataset_path}_{dataset_name}_{split}.jsonl" ) with open(manifest_path, "w", encoding="utf-8") as f: for idx, (text, transcript, audio_length, transcription_time) in enumerate( zip(references, transcriptions, audio_length, transcription_time) ): datum = { "audio_filepath": f"sample_{idx}", # dummy value for Speech Data Processor "duration": audio_length, "time": transcription_time, "text": text, "pred_text": transcript, } f.write(f"{json.dumps(datum, ensure_ascii=False)}\n") return manifest_path def score_results(directory: str, model_id: str = None): """ Scores all result files in a directory and returns a composite score over all evaluated datasets. Args: directory: Path to the result directory, containing one or more jsonl files. model_id: Optional, model name to filter out result files based on model name. Returns: Composite score over all evaluated datasets and a dictionary of all results. """ # Strip trailing slash if directory.endswith(os.pathsep): directory = directory[:-1] # Find all result files in the directory result_files = list(glob.glob(f"{directory}/**/*.jsonl", recursive=True)) result_files = list(sorted(result_files)) # Filter files belonging to a specific model id if model_id is not None and model_id != "": print("Filtering models by id:", model_id) model_id = model_id.replace("/", "-") result_files = [fp for fp in result_files if model_id in fp] # Check if any result files were found if len(result_files) == 0: raise ValueError(f"No result files found in {directory}") # Utility function to parse the file path and extract model id, dataset path, dataset name and split def parse_filepath(fp: str): model_index = fp.find("MODEL_") fp = fp[model_index:] ds_index = fp.find("DATASET_") model_id = fp[:ds_index].replace("MODEL_", "").rstrip("_") author_index = model_id.find("-") model_id = model_id[:author_index] + "/" + model_id[author_index + 1 :] ds_fp = fp[ds_index:] dataset_id = ds_fp.replace("DATASET_", "").rstrip(".jsonl") return model_id, dataset_id # Compute WER results per dataset, and RTFx over all datasets results = {} wer_metric = evaluate.load("wer") for result_file in result_files: manifest = read_manifest(result_file) model_id_of_file, dataset_id = parse_filepath(result_file) references = [datum["text"] for datum in manifest] predictions = [datum["pred_text"] for datum in manifest] time = [datum["time"] for datum in manifest] duration = [datum["duration"] for datum in manifest] compute_rtfx = all(time) and all(duration) wer = wer_metric.compute(references=references, predictions=predictions) wer = round(100 * wer, 2) if compute_rtfx: audio_length = sum(duration) inference_time = sum(time) rtfx = round(sum(duration) / sum(time), 4) else: audio_length = inference_time = rtfx = None result_key = f"{model_id_of_file} | {dataset_id}" results[result_key] = {"wer": wer, "audio_length": audio_length, "inference_time": inference_time, "rtfx": rtfx} print("*" * 80) print("Results per dataset:") print("*" * 80) for k, v in results.items(): metrics = f"{k}: WER = {v['wer']:0.2f} %" if v["rtfx"] is not None: metrics += f", RTFx = {v['rtfx']:0.2f}" print(metrics) # composite WER should be computed over all datasets and with the same key composite_wer = defaultdict(float) composite_audio_length = defaultdict(float) composite_inference_time = defaultdict(float) count_entries = defaultdict(int) for k, v in results.items(): key = k.split("|")[0].strip() composite_wer[key] += v["wer"] if v["rtfx"] is not None: composite_audio_length[key] += v["audio_length"] composite_inference_time[key] += v["inference_time"] else: composite_audio_length[key] = composite_inference_time[key] = None count_entries[key] += 1 # normalize scores & print print() print("*" * 80) print("Composite Results:") print("*" * 80) for k, v in composite_wer.items(): wer = v / count_entries[k] print(f"{k}: WER = {wer:0.2f} %") for k in composite_audio_length: if composite_audio_length[k] is not None: rtfx = composite_audio_length[k] / composite_inference_time[k] print(f"{k}: RTFx = {rtfx:0.2f}") print("*" * 80) return composite_wer, results