# phi/run_eval.py
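"""Evaluate an audio-capable chat model (e.g. Phi-4-multimodal) on ESB-style ASR test sets.

Each audio clip is transcribed via a chat-style prompt, the hypotheses are normalized with the
shared English text normalizer, and the script reports word error rate (WER) and the inverse
real-time factor (RTFx).

Example invocation (the model and dataset identifiers below are illustrative):

    python run_eval.py \
        --model_id microsoft/Phi-4-multimodal-instruct \
        --dataset librispeech_asr \
        --split test \
        --device 0 \
        --batch_size 16 \
        --max_new_tokens 256
"""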
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, StoppingCriteria, StoppingCriteriaList
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm
wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')
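# generate() stops on single end-of-sequence token ids; this criterion additionally handles
# multi-token stop sequences in batched (and beam-search) generation and records, per batch
# entry, the position at which a stop sequence was produced.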
class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
    """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs."""

    def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
        """Initialize the multiple token batch stopping criteria.

        Args:
            stop_tokens: Stop-tokens.
            batch_size: Batch size.
        """
        self.stop_tokens = stop_tokens
        self.max_stop_tokens = stop_tokens.shape[-1]
        self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only gather the maximum number of inputs compatible with stop tokens
        # and check whether the generated inputs are equal to `stop_tokens`
        generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
        equal_generated_inputs = torch.all(generated_inputs, dim=2)

        # Mark the position where a stop token has been produced for each input in the batch,
        # but only if the corresponding entry is not already set
        sequence_idx = torch.any(equal_generated_inputs, dim=1)
        sequence_set_mask = self.stop_tokens_idx == 0
        self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]

        return torch.all(self.stop_tokens_idx)
def main(args):
    # Map the integer --device argument to a torch device (-1 means CPU, otherwise a CUDA index).
    device = "cpu" if args.device < 0 else f"cuda:{args.device}"

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="flash_attention_2",
    ).to(device)
    model.eval()

    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
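    # Chat-style prompt: a user turn holding the audio placeholder and the transcription
    # instruction, followed by an open assistant turn for the model to complete.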
    user = "<|user|>"
    assistant = "<|assistant|>"
    prompt_suffix = "<|end|>"
    prompt = f"{user}<|audio_1|>{args.user_prompt}{prompt_suffix}{assistant}"

    gen_kwargs = {"max_new_tokens": args.max_new_tokens, "num_beams": args.num_beams}

    stop_tokens = [prompt_suffix, processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt")["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(model.device)
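    # benchmark() transcribes one mini-batch: preprocess the audio, generate, trim each
    # hypothesis at its stop token, and record the per-sample transcription time.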
    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        audios = [(audio["array"], audio["sampling_rate"]) for audio in batch["audio"]]
        minibatch_size = len(audios)
        gen_kwargs["stopping_criteria"] = StoppingCriteriaList(
            [MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=args.num_beams * minibatch_size)]
        )

        # START TIMING
        start_time = time.time()

        with torch.autocast(model.device.type, enabled=True):
            inputs = processor(text=[prompt] * minibatch_size, audios=audios, return_tensors="pt").to(device)

            # Model Inference
            pred_ids = model.generate(
                **inputs,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                **gen_kwargs,
                min_new_tokens=min_new_tokens,
            )

        # Gather the sequence index of the stop token
        stop_tokens_idx = gen_kwargs["stopping_criteria"][0].stop_tokens_idx.reshape(minibatch_size, -1)[:, 0]

        # If a stop token was produced, we need to remove its length from the found index;
        # however, there is a chance that the stop token was not produced and the index
        # returned is the length of the generated sequence
        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            pred_ids.shape[-1],
        )

        # Convert token ids to text transcription
        pred_text = [
            processor.decode(_pred_ids[inputs["input_ids"].shape[1] : _stop_tokens_idx], skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for _pred_ids, _stop_tokens_idx in zip(pred_ids, stop_tokens_idx)
        ]

        # END TIMING
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]

        return batch
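    # Optional warm-up pass: run a few batches (forcing max_new_tokens to be generated) so
    # one-off startup costs do not distort the timed evaluation below.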
    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue
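    # Timed evaluation: reload the dataset, optionally subsample it, and transcribe every batch.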
    dataset = data_utils.load_data(args)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )
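    # Accumulate the per-sample fields needed for the manifest, WER, and RTFx.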
    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFx)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))
    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)

    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="Number of beams for beam search.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Use a small number, e.g. 64, to test this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    parser.add_argument(
        "--user_prompt",
        type=str,
        default="Transcribe the audio clip into text.",
        help="User prompt string.",
    )
    # Stream by default; set_defaults must run before parse_args() for the default to take
    # effect, and --no-streaming (store_false) turns streaming off.
    parser.set_defaults(streaming=True)
    args = parser.parse_args()

    main(args)