nemo_asr/run_eval.py (159 lines of code) (raw):

"""Evaluate an NVIDIA NeMo ASR model on an ESB-style dataset and report WER and RTFx.

Each audio sample is written as a WAV file under
``<cwd>/audio_cache/<dataset>/<split>``. The files are transcribed with the NeMo
model: a short, untimed warm-up pass first, then one timed pass over every file.
Predictions are normalized with ``data_utils.normalizer`` and scored against the
reference texts with the ``evaluate`` WER metric.
"""

import argparse
import io
import os
import time

import evaluate
import numpy as np
import soundfile
import torch
from nemo.collections.asr.models import ASRModel
from tqdm import tqdm

from normalizer import data_utils

wer_metric = evaluate.load("wer")

# Number of batches transcribed (untimed) before the measured run.
_WARMUP_BATCHES = 4


def _load_model(model_id, device, compute_dtype):
    """Load a NeMo ASR model from a ``.nemo`` file or a pretrained model name.

    The model is moved to ``device``, cast to ``compute_dtype`` and put in eval
    mode. Its decoding strategy is forced to ``greedy_batch`` unless the config
    already requests beam search.
    """
    if model_id.endswith(".nemo"):
        asr_model = ASRModel.restore_from(model_id, map_location=device)
    else:
        asr_model = ASRModel.from_pretrained(model_id, map_location=device)  # type: ASRModel
    asr_model.to(compute_dtype)
    asr_model.eval()

    if asr_model.cfg.decoding.strategy != "beam":
        asr_model.cfg.decoding.strategy = "greedy_batch"
        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
    return asr_model


def _decode_audio(sample):
    """Return ``(float32 waveform, sample_rate)`` for one dataset audio entry.

    Accepts either an already-decoded entry (``array`` key) or a raw byte
    stream (``bytes`` key), as produced by datasets 3.x.

    Raises:
        ValueError: if the entry has neither an ``array`` nor a ``bytes`` key.
    """
    if "array" in sample:
        # Use the rate attached to the decoded sample when present; fall back to
        # 16 kHz (the previously hard-coded value) when it is missing or None.
        sample_rate = sample.get("sampling_rate") or 16000
        return np.asarray(sample["array"], dtype=np.float32), sample_rate
    if "bytes" in sample:
        with io.BytesIO(sample["bytes"]) as audio_file:
            audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
        return audio_array, sample_rate
    raise ValueError("Sample must have either 'array' or 'bytes' key")


def _sort_by_duration(all_data):
    """Return a copy of the column dict ``all_data``, with rows reordered longest-first.

    Every list in ``all_data`` is reordered by the same permutation. That
    permutation is a stable sort on ``all_data["durations"]``, descending.
    """
    durations = all_data["durations"]
    order = sorted(range(len(durations)), key=durations.__getitem__, reverse=True)
    return {key: [values[i] for i in order] for key, values in all_data.items()}


def _transcribe(asr_model, audio_files, model_id, batch_size, device):
    """Transcribe ``audio_files`` with autocast disabled and autograd off."""
    transcribe_kwargs = {"batch_size": batch_size, "verbose": False, "num_workers": 1}
    if "canary" in model_id:
        # Canary models take an extra `pnc` option. "no" is passed, presumably to
        # request output without punctuation/capitalization; confirm against NeMo docs.
        transcribe_kwargs["pnc"] = "no"
    # The model was already cast to the compute dtype, so autocast stays disabled.
    with torch.autocast(device_type=device.type, enabled=False), torch.inference_mode():
        return asr_model.transcribe(audio_files, **transcribe_kwargs)


def main(args):
    """Run the evaluation described by the parsed command-line ``args``.

    Loads the model and dataset, caches every sample as a WAV file, and
    transcribes the files: a warm-up pass, then a timed pass over all files.
    Results are written through ``data_utils.write_manifest``, and RTFx and
    WER are printed.

    Raises:
        ValueError: if a sample carries no ``array``/``bytes`` audio payload, or
            if the dataset yields no samples to evaluate.
    """
    cache_dir = os.path.join(os.getcwd(), "audio_cache", args.dataset, args.split)
    os.makedirs(cache_dir, exist_ok=True)

    if args.device >= 0:
        device = torch.device(f"cuda:{args.device}")
        compute_dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        compute_dtype = torch.float32

    asr_model = _load_model(args.model_id, device, compute_dtype)

    dataset = data_utils.load_data(args)

    def download_audio_files(batch):
        """Write a batch of samples to ``cache_dir``.

        Adds three columns to the batch: ``audio_filepaths``, ``durations``
        (in seconds) and ``references`` (a copy of ``norm_text``).
        """
        audio_paths = []
        durations = []
        for sample_id, sample in zip(batch["id"], batch["audio"]):
            # Make IDs and WAV filenames unique. Several datasets (e.g. earnings22)
            # are hierarchical, e.g. earnings22/test/4432298/281.wav and
            # earnings22/test/4450488/281.wav, and lhotse uses the bare filename
            # (281.wav) as the unique ID when it creates and names cuts.
            # ref: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/collation.py#L186
            sample_id = str(sample_id).replace("/", "_").removesuffix(".wav")
            audio_path = os.path.join(cache_dir, f"{sample_id}.wav")

            audio_array, sample_rate = _decode_audio(sample)
            if not os.path.exists(audio_path):
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                soundfile.write(audio_path, audio_array, sample_rate)

            audio_paths.append(audio_path)
            durations.append(len(audio_array) / sample_rate)

        batch["references"] = batch["norm_text"]
        batch["audio_filepaths"] = audio_paths
        batch["durations"] = durations
        return batch

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
        dataset = dataset.take(args.max_eval_samples)

    dataset = data_utils.prepare_data(dataset)

    # Prepare the offline dataset: cache audio to disk and record paths/durations.
    dataset = dataset.map(
        download_audio_files, batch_size=args.batch_size, batched=True, remove_columns=["audio"]
    )

    # Collect the columns needed for transcription and the results manifest.
    all_data = {"audio_filepaths": [], "durations": [], "references": []}
    for data in tqdm(dataset, desc="Downloading Samples"):
        for key, values in all_data.items():
            values.append(data[key])

    if not all_data["audio_filepaths"]:
        raise ValueError("No samples to evaluate: the dataset split yielded no audio.")

    all_data = _sort_by_duration(all_data)
    audio_files = all_data["audio_filepaths"]
    num_files = len(audio_files)

    # Untimed warm-up on the first few batches (the longest files, given the
    # sort above); its output is discarded.
    _transcribe(
        asr_model, audio_files[: args.batch_size * _WARMUP_BATCHES], args.model_id, args.batch_size, device
    )

    start_time = time.perf_counter()
    transcriptions = _transcribe(asr_model, audio_files, args.model_id, args.batch_size, device)
    total_time = time.perf_counter() - start_time

    # Some models return a 2-tuple (presumably best and all hypotheses); keep the first.
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    # Normalize transcriptions with the English normalizer.
    predictions = [data_utils.normalizer(pred.text) for pred in transcriptions]

    avg_time = total_time / num_files

    # Write manifest results (WER and RTFx).
    manifest_path = data_utils.write_manifest(
        all_data["references"],
        predictions,
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_data["durations"],
        transcription_time=[avg_time] * num_files,
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(references=all_data["references"], predictions=predictions)
    wer = round(100 * wer, 2)

    audio_length = sum(all_data["durations"])
    rtfx = round(audio_length / total_time, 2)
    print("RTFX:", rtfx)
    print("WER:", wer, "%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with NVIDIA NeMo.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    # Defaults must be registered before parse_args(); setting them afterwards has no effect.
    parser.set_defaults(streaming=True)
    args = parser.parse_args()

    main(args)