phi/run_eval.py (197 lines of code) (raw):

import argparse
import os
import time

import evaluate
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, StoppingCriteria, StoppingCriteriaList

from normalizer import data_utils

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')


class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
    """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs.

    For every row of the (possibly beam-expanded) batch, the sequence length at
    which a stop sequence was first observed is stored in ``stop_tokens_idx``
    (0 means "not seen yet"). Generation is halted only once every row has
    produced a stop sequence; callers read ``stop_tokens_idx`` afterwards to
    truncate each row individually.
    """

    def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
        """Initialize the multiple token batch stopping criteria.

        Args:
            stop_tokens: 2-D tensor of stop sequences, one (padded) row per
                stop sequence.
            batch_size: Number of rows generated in parallel (beams included).
        """
        self.stop_tokens = stop_tokens
        self.max_stop_tokens = stop_tokens.shape[-1]
        self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Compare the trailing `max_stop_tokens` ids of every row against every
        # stop sequence: (batch, 1, L) vs (num_stop, L) -> (batch, num_stop, L).
        generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
        equal_generated_inputs = torch.all(generated_inputs, dim=2)

        # Mark the position where a stop token has been produced for each input
        # in the batch, but only if the corresponding entry is not already set.
        sequence_idx = torch.any(equal_generated_inputs, dim=1)
        sequence_set_mask = self.stop_tokens_idx == 0
        self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]

        # Scalar boolean tensor: stop only once every row has finished.
        return torch.all(self.stop_tokens_idx)


def _resolve_device(device_index: int) -> str:
    """Translate the integer ``--device`` flag into a torch device string.

    Negative values select the CPU; ``N >= 0`` selects ``cuda:N``. Passing a
    negative integer straight to ``.to(...)`` is rejected by torch, which is
    why the documented "-1 for CPU" default needs this translation.
    """
    if device_index < 0:
        return "cpu"
    return f"cuda:{device_index}"


def main(args):
    """Transcribe a dataset split with a Phi multimodal model and report WER / RTFx.

    Loads the model and processor, optionally runs ``args.warmup_steps``
    untimed-for-reporting warm-up batches, transcribes the (optionally
    subsampled) dataset in batches of ``args.batch_size``, writes a manifest via
    ``data_utils.write_manifest`` and prints the manifest path, WER and RTFx.
    """
    device = _resolve_device(args.device)
    # flash_attention_2 only runs with CUDA kernels; use eager attention on CPU.
    attn_implementation = "flash_attention_2" if device.startswith("cuda") else "eager"

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation=attn_implementation,
    ).to(device)
    model.eval()
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)

    user = "<|user|>"
    assistant = "<|assistant|>"
    prompt_suffix = "<|end|>"
    prompt = f"{user}<|audio_1|>{args.user_prompt}{prompt_suffix}{assistant}"

    gen_kwargs = {"max_new_tokens": args.max_new_tokens, "num_beams": args.num_beams}

    # Either the end-of-turn marker or the tokenizer's EOS ends a transcription.
    stop_tokens = [prompt_suffix, processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(
        stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt"
    )["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(model.device)

    def benchmark(batch, min_new_tokens=None):
        """Transcribe one batch, adding timing, predictions and references."""
        # Load audio inputs
        audios = [(audio["array"], audio["sampling_rate"]) for audio in batch["audio"]]
        minibatch_size = len(audios)
        # A fresh criteria object per batch: it carries per-row state.
        gen_kwargs["stopping_criteria"] = StoppingCriteriaList(
            [MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=args.num_beams * minibatch_size)]
        )

        # START TIMING (perf_counter is monotonic and high-resolution)
        start_time = time.perf_counter()

        with torch.autocast(model.device.type, enabled=True):
            inputs = processor(text=[prompt] * minibatch_size, audios=audios, return_tensors="pt").to(device)

            # Model Inference
            pred_ids = model.generate(
                **inputs,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                **gen_kwargs,
                min_new_tokens=min_new_tokens,
            )

        # Gather the sequence index of the stop token (first beam of each sample)
        stop_tokens_idx = gen_kwargs["stopping_criteria"][0].stop_tokens_idx.reshape(minibatch_size, -1)[:, 0]

        # If a stop token was produced, we need to remove its length from the found index,
        # however there might be a chance that the stop token was not produced and the index
        # returned is the length of the generated sequence
        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            pred_ids.shape[-1],
        )

        # Convert token ids to text transcription; the returned ids start with
        # the prompt, so slice it off before decoding.
        prompt_length = inputs["input_ids"].shape[1]
        pred_text = [
            processor.decode(
                _pred_ids[prompt_length:_stop_tokens_idx],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            for _pred_ids, _stop_tokens_idx in zip(pred_ids, stop_tokens_idx)
        ]

        # END TIMING
        runtime = time.perf_counter() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        # Force max-length generations so warm-up exercises the full decode path.
        warmup_dataset = iter(
            warmup_dataset.map(
                benchmark,
                batch_size=args.batch_size,
                batched=True,
                fn_kwargs={"min_new_tokens": args.max_new_tokens},
            )
        )

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    for result in tqdm(iter(dataset), desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Nothing to score: WER on empty lists and RTFx (division by a zero total
    # time) would both fail.
    if not all_results["predictions"]:
        print("No samples were evaluated; skipping manifest, WER and RTFx.")
        return

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="Number of beams for beam search.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    # store_false: streaming is True unless --no-streaming is given.
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    parser.add_argument(
        "--user_prompt",
        type=str,
        default="Transcribe the audio clip into text.",
        help="User prompt string.",
    )
    args = parser.parse_args()

    main(args)