in speechbrain/run_eval.py [0:0]
def main(args):
    """Run the evaluation script.

    Builds the model via ``get_model``, optionally runs a few warmup batches,
    transcribes the evaluation dataset batch by batch, writes a results
    manifest via ``data_utils.write_manifest`` and prints the word error rate
    (WER, in percent) and RTFx (total audio seconds divided by total
    transcription seconds).

    Args:
        args: Parsed command-line arguments. Fields read directly here:
            ``device`` (``-1`` selects CPU, any other value is used as a CUDA
            device index), ``source``, ``speechbrain_pretrained_class_name``,
            ``beam_size``, ``ctc_weight_decode``, ``warmup_steps``,
            ``batch_size``, ``streaming``, ``max_eval_samples``,
            ``dataset_path``, ``dataset`` and ``split``.
            ``data_utils.load_data`` receives the whole namespace and may read
            further fields.

    Raises:
        RuntimeError: If no samples were transcribed, in which case WER and
            RTFx are undefined.
    """
    import contextlib

    device = "cpu" if args.device == -1 else f"cuda:{args.device}"
    model = get_model(
        args.source,
        args.speechbrain_pretrained_class_name,
        args.beam_size,
        args.ctc_weight_decode,
        device=device,
    )

    def autocast_context():
        """Return the mixed-precision context to use around inference."""
        # A CUDA autocast context has no effect on CPU tensors and warns when
        # CUDA is unavailable, so only enable it for GPU runs.
        if device == "cpu":
            return contextlib.nullcontext()
        return torch.autocast(device_type="cuda")

    def benchmark(batch):
        """Transcribe one batch and attach predictions, references and timings.

        Adds the columns ``transcription_time_s``, ``predictions`` and
        ``references`` to ``batch`` and returns it.
        """
        # Load audio inputs and right-pad them into one batch tensor.
        audios = [torch.from_numpy(sample["array"]) for sample in batch["audio"]]
        minibatch_size = len(audios)
        audios, audio_lens = batch_pad_right(audios)
        audios = audios.to(device)
        audio_lens = audio_lens.to(device)

        # Only the model call is timed: padding and host-to-device copies
        # above are excluded. perf_counter is monotonic and high-resolution.
        start_time = time.perf_counter()
        with autocast_context():
            predictions, _ = model.transcribe_batch(audios, audio_lens)
        runtime = time.perf_counter() - start_time

        # Split the batch runtime evenly across its samples so the per-sample
        # times sum back to the batch total.
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in predictions]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        # Push a few batches through the model first so that one-time start-up
        # costs are kept out of the measured run. The results are discarded.
        warmup_dataset = data_utils.load_data(args)
        warmup_dataset = data_utils.prepare_data(warmup_dataset)
        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = warmup_dataset.take(num_warmup_samples)
        else:
            warmup_dataset = warmup_dataset.select(
                range(min(num_warmup_samples, len(warmup_dataset)))
            )
        warmup_iter = iter(
            warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True)
        )
        for _ in tqdm(warmup_iter, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)
    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    for result in tqdm(iter(dataset), desc="Samples..."):
        for key, values in all_results.items():
            values.append(result[key])

    # Without any samples WER and RTFx are undefined. Before this guard the
    # failure was an opaque error raised after writing an empty manifest.
    if not all_results["references"]:
        raise RuntimeError("No samples were evaluated; cannot compute WER or RTFx.")

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.source,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    # RTFx: seconds of audio transcribed per second of transcription time.
    rtfx = round(
        sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]),
        2,
    )
    print("WER:", wer, "%", "RTFx:", rtfx)