ctranslate2/run_eval.py
"""Run evaluation for ctranslate2 whisper models."""""
import argparse
import os
import time

import evaluate
from faster_whisper import WhisperModel
from tqdm import tqdm

from normalizer import data_utils

wer_metric = evaluate.load("wer")


def main(args) -> None:
    """Main function to run evaluation on a dataset."""
    asr_model = WhisperModel(
        model_size_or_path=args.model_id,
        compute_type="float16",
        device="cuda",
        device_index=args.device,
    )
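
    # Transcribe a single sample, timing the call and storing the normalized prediction
    # alongside the normalized reference text.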
    def benchmark(batch):
        start_time = time.time()
        segments, _ = asr_model.transcribe(batch["audio"]["array"], language="en")
        outputs = [segment._asdict() for segment in segments]
        batch["transcription_time_s"] = time.time() - start_time
        batch["predictions"] = data_utils.normalizer(
            "".join([segment["text"] for segment in outputs])
        ).strip()
        batch["references"] = batch["norm_text"]
        return batch
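
    # Optional warm-up pass so model loading and kernel initialization do not skew the timed run.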
    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        if args.streaming:
            warmup_dataset = dataset.take(args.warmup_steps)
        else:
            warmup_dataset = dataset.select(range(min(args.warmup_steps, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, remove_columns=["audio"]))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue
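
    # Re-load the dataset for the timed evaluation run.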
    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(benchmark, remove_columns=["audio"])
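
    # Collect per-sample audio durations, transcription times, predictions, and references.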
    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))
    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)

    print("WER:", wer, "%", "RTFx:", rtfx)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id",
type=str,
required=True,
help="Model identifier. Should be loadable with faster-whisper",
)
parser.add_argument(
'--dataset_path', type=str, default='esb/datasets', help='Dataset path. By default, it is `esb/datasets`'
)
parser.add_argument(
"--dataset",
type=str,
required=True,
help="Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
"can be found at `https://huggingface.co/datasets/esb/datasets`"
)
parser.add_argument(
"--split",
type=str,
default="test",
help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
)
parser.add_argument(
"--device",
type=int,
default=-1,
help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
)
parser.add_argument(
"--no-streaming",
dest='streaming',
action="store_false",
help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=5,
help="Number of warm-up steps to run before launching the timed runs.",
)
args = parser.parse_args()
parser.set_defaults(streaming=False)
main(args)