kyutai/run_eval.py (222 lines of code) (raw):

"""Evaluate a Kyutai (moshi) speech-to-text checkpoint on an ESB-style ASR dataset.

Transcribes the dataset with a streaming Mimi + LM pipeline, writes a results
manifest via ``normalizer.data_utils.write_manifest`` and prints WER and RTFx.
"""

import argparse
import os
import time

import evaluate
import julius
import torch
from moshi import models
from normalizer import data_utils
from tqdm import tqdm

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision("high")


def load_model(model_path, device="cuda"):
    """Load the Mimi codec, text tokenizer and STT language model from a HF repo.

    Args:
        model_path: repo id passed to ``CheckpointInfo.from_hf_repo``.
        device: device the Mimi codec and the LM are placed on
            (default ``"cuda"``, i.e. the current CUDA device).

    Returns:
        Tuple ``(mimi, tokenizer, lm, lm_gen, padding_token_id,
        audio_silence_prefix_seconds, audio_delay_seconds)``.
    """
    info = models.loaders.CheckpointInfo.from_hf_repo(model_path)
    mimi = info.get_mimi(device=device)
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device=device,
        dtype=torch.bfloat16,
    )
    # Temperature 0 for both audio and text streams (presumably greedy decoding).
    lm_gen = models.LMGen(lm, temp=0, temp_text=0.0)
    padding_token_id = info.raw_config.get("text_padding_token_id", 3)
    # Putting in some conservative defaults
    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
    return (
        mimi,
        tokenizer,
        lm,
        lm_gen,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds,
    )


@torch.inference_mode
def get_padded_batch(
    audios,
    sample_rates,
    before_padding: float,
    after_padding: float,
    frame_size: int,
    sample_rate: int = 24_000,
):
    """Resample, pad and stack waveforms into a single batch tensor.

    Each waveform is resampled from its own rate to ``sample_rate``, padded
    with ``before_padding`` seconds of zeros at the start and
    ``after_padding`` seconds at the end, then right-padded with zeros so all
    items share a length that is a multiple of ``frame_size``.

    Args:
        audios: sequence of waveform tensors (assumed 1-D mono -- TODO confirm).
        sample_rates: sampling rate of each waveform, aligned with ``audios``.
        before_padding: seconds of silence prepended to every waveform.
        after_padding: seconds of silence appended to every waveform.
        frame_size: length (in samples) the final batch length must divide by.
        sample_rate: target sampling rate; defaults to 24 kHz.

    Returns:
        ``torch.stack`` of the padded waveforms; the last dimension is a
        multiple of ``frame_size``.

    Raises:
        ValueError: if ``audios`` is empty.
    """
    before = int(before_padding * sample_rate)
    after = int(after_padding * sample_rate)

    batch = []
    for audio, sr in zip(audios, sample_rates):
        audio = julius.resample.resample_frac(audio, old_sr=sr, new_sr=sample_rate)
        batch.append(torch.nn.functional.pad(audio, (before, after)))

    if not batch:
        raise ValueError("get_padded_batch requires at least one audio input")

    max_len = max(audio.shape[-1] for audio in batch)
    # Round up to a whole number of frames so the streaming loop sees full chunks.
    target = (max_len + frame_size - 1) // frame_size * frame_size
    return torch.stack(
        [
            torch.nn.functional.pad(audio, (0, target - audio.shape[-1]))
            for audio in batch
        ]
    )


def _resolve_device(device_index):
    """Map the ``--device`` CLI value to a torch device (-1 -> CPU, k -> cuda:k)."""
    if device_index < 0:
        return torch.device("cpu")
    # Make cuda:k the current device so anything that relies on the default
    # CUDA device lands on the same GPU as the model and the inputs.
    torch.cuda.set_device(device_index)
    return torch.device("cuda", device_index)


def _take_first(dataset, num_samples, streaming):
    """Return the first ``num_samples`` examples of a streaming or map-style dataset."""
    if streaming:
        return dataset.take(num_samples)
    return dataset.select(range(min(num_samples, len(dataset))))


def main(args):
    """Run optional warm-up, transcribe the dataset, then report WER and RTFx."""
    device = _resolve_device(args.device)
    (
        mimi,
        tokenizer,
        _lm,
        lm_gen,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds,
    ) = load_model(args.model_id, device=device)
    mimi_frame_size = mimi.frame_size

    def benchmark(batch):
        """Transcribe one dataset batch and record timing and normalized texts."""
        # Load audio inputs
        audios = [torch.from_numpy(audio["array"]) for audio in batch["audio"]]
        sample_rates = [ex["sampling_rate"] for ex in batch["audio"]]
        # Use each example's own sampling rate: a batch may mix rates.
        batch["audio_length_s"] = [
            len(audio) / sr for audio, sr in zip(audios, sample_rates)
        ]
        minibatch_size = len(audios)

        # Start timing (perf_counter is monotonic, unlike time.time)
        start_time = time.perf_counter()

        padded_batch = get_padded_batch(
            audios,
            sample_rates,
            before_padding=audio_silence_prefix_seconds,
            after_padding=audio_delay_seconds,
            frame_size=mimi_frame_size,
        )
        # Move once to the model's device; chunks below are already there.
        padded_batch = padded_batch.to(device).float()
        bsz = padded_batch.shape[0]

        text_tokens_acc = []
        with mimi.streaming(bsz), lm_gen.streaming(bsz):
            for offset in range(0, padded_batch.shape[-1], mimi_frame_size):
                audio_chunk = padded_batch[:, offset : offset + mimi_frame_size]
                tokens = mimi.encode(audio_chunk[:, None, :])
                text_tokens = lm_gen.step(tokens)
                text_tokens_acc.append(text_tokens)

        pred_tokens = torch.concat(text_tokens_acc, dim=-1).squeeze(dim=1)
        pred_tokens = torch.unbind(pred_tokens, dim=0)
        # Drop padding/special ids (<= padding_token_id) before decoding.
        pred_text = [
            tokenizer.decode(t[t > padding_token_id].cpu().numpy().tolist())
            for t in pred_tokens
        ]

        # End timing
        runtime = time.perf_counter() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        warmup_dataset = data_utils.load_data(args)
        warmup_dataset = data_utils.prepare_data(warmup_dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        warmup_dataset = _take_first(warmup_dataset, num_warmup_samples, args.streaming)
        warmup_dataset = iter(
            warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True)
        )

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    dataset = data_utils.prepare_data(dataset)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = _take_first(dataset, args.max_eval_samples, args.streaming)

    dataset = dataset.map(
        benchmark,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(
        sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]),
        2,
    )
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    # store_false: streaming defaults to True; pass --no-streaming to disable it.
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )

    args = parser.parse_args()

    main(args)