api/run_eval.py
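"""Benchmark hosted speech-to-text APIs against a Hugging Face dataset.

Providers are selected via a model-name prefix: 'assembly/', 'openai/',
'elevenlabs/', 'revai/', or 'speechmatics/'. Audio is either fetched by URL
from the datasets-server rows API (--use_url) or loaded locally with
`datasets` and written to temporary WAV files. Results are written to a
manifest and summarized as WER and RTFx.
"""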

import argparse
import concurrent.futures
import itertools
import os
import tempfile
import time
from io import BytesIO
from typing import Optional

import assemblyai as aai
import datasets
import evaluate
import openai
import requests
import soundfile as sf
from dotenv import load_dotenv
from elevenlabs.client import ElevenLabs
from httpx import HTTPStatusError
from requests_toolbelt import MultipartEncoder
from rev_ai import apiclient
from rev_ai.models import CustomerUrlData
from speechmatics.batch_client import BatchClient
from speechmatics.models import BatchTranscriptionConfig, ConnectionSettings, FetchData
from tqdm import tqdm

from normalizer import data_utils

# Load API keys (HF_TOKEN, ASSEMBLYAI_API_KEY, ...) from a local .env file.
load_dotenv()


def fetch_audio_urls(dataset_path, dataset, split, batch_size=100, max_retries=20):
    """Yield dataset rows (including signed audio URLs) from the Hugging Face
    datasets-server /rows API, `batch_size` rows at a time."""
    API_URL = "https://datasets-server.huggingface.co/rows"
    size_url = (
        f"https://datasets-server.huggingface.co/size"
        f"?dataset={dataset_path}&config={dataset}&split={split}"
    )
    size_response = requests.get(size_url).json()
    total_rows = size_response["size"]["config"]["num_rows"]

    for offset in tqdm(range(0, total_rows, batch_size), desc="Fetching audio URLs"):
        params = {
            "dataset": dataset_path,
            "config": dataset,
            "split": split,
            "offset": offset,
            "length": min(batch_size, total_rows - offset),
        }
        retries = 0
        while retries <= max_retries:
            try:
                headers = {}
                if os.environ.get("HF_TOKEN") is not None:
                    headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
                else:
                    print("HF_TOKEN not set, might experience rate-limiting.")
                response = requests.get(API_URL, params=params, headers=headers)
                response.raise_for_status()
                data = response.json()
                yield from data["rows"]
                break
            except (requests.exceptions.RequestException, ValueError) as e:
                retries += 1
                if retries >= max_retries:
                    raise Exception("Max retries exceeded while fetching data.")
                print(
                    f"Error fetching data: {e}, retrying ({retries}/{max_retries})..."
                )
                time.sleep(10)


def transcribe_with_retry(
    model_name: str,
    audio_file_path: Optional[str],
    sample: dict,
    max_retries=10,
    use_url=False,
):
    """Transcribe a single sample with the provider selected by the
    `model_name` prefix, retrying transient API errors up to `max_retries`."""
    retries = 0
    while retries <= max_retries:
        try:
            PREFIX = "speechmatics/"
            if model_name.startswith(PREFIX):
                api_key = os.getenv("SPEECHMATICS_API_KEY")
                if not api_key:
                    raise ValueError(
                        "SPEECHMATICS_API_KEY environment variable not set"
                    )
                settings = ConnectionSettings(
                    url="https://asr.api.speechmatics.com/v2", auth_token=api_key
                )
                with BatchClient(settings) as client:
                    config = BatchTranscriptionConfig(
                        language="en",
                        enable_entities=True,
                        operating_point=model_name[len(PREFIX):],
                    )
                    job_id = None
                    audio_url = None
                    try:
                        if use_url:
                            # Let Speechmatics fetch the audio itself from the URL.
                            audio_url = sample["row"]["audio"][0]["src"]
                            config.fetch_data = FetchData(url=audio_url)
                            multipart_data = MultipartEncoder(
                                fields={"config": config.as_config().encode("utf-8")}
                            )
                            response = client.send_request(
                                "POST",
                                "jobs",
                                data=multipart_data.to_string(),
                                headers={"Content-Type": multipart_data.content_type},
                            )
                            job_id = response.json()["id"]
                        else:
                            job_id = client.submit_job(audio_file_path, config)
                        transcript = client.wait_for_completion(
                            job_id, transcription_format="txt"
                        )
                        return transcript
                    except HTTPStatusError as e:
                        if e.response.status_code == 401:
                            raise ValueError(
                                "Invalid Speechmatics API credentials"
                            ) from e
                        elif e.response.status_code == 400:
                            raise ValueError(
                                f"Speechmatics API responded with 400 Bad request: "
                                f"{e.response.text}"
                            ) from e
                        raise e
                    except Exception as e:
                        if job_id is not None:
                            status = client.check_job_status(job_id)
                            if (
                                audio_url is not None
                                and "job" in status
                                and "errors" in status["job"]
                                and isinstance(status["job"]["errors"], list)
                                and len(status["job"]["errors"]) > 0
                            ):
                                errors = status["job"]["errors"]
                                if (
                                    "message" in errors[-1]
                                    and "failed to fetch file" in errors[-1]["message"]
                                ):
                                    # An unfetchable URL is a permanent failure: mark
                                    # it non-retryable so the outer handler re-raises.
                                    retries = max_retries + 1
                                    raise Exception(
                                        f"could not fetch URL {audio_url}, not retrying"
                                    )
                        raise Exception(
                            f"Speechmatics transcription failed: {str(e)}"
                        ) from e
            elif model_name.startswith("assembly/"):
                aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY")
                transcriber = aai.Transcriber()
                config = aai.TranscriptionConfig(
                    speech_model=model_name.split("/")[1],
                    language_code="en",
                )
                if use_url:
                    audio_duration = sample["row"]["audio_length_s"]
                else:
                    audio_duration = (
                        len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
                    )
                # Very short clips (< 160 ms) are skipped and given a
                # placeholder transcript.
                if audio_duration < 0.160:
                    print(f"Skipping audio with duration {audio_duration}s")
                    return "."
                if use_url:
                    transcript = transcriber.transcribe(
                        sample["row"]["audio"][0]["src"], config=config
                    )
                else:
                    transcript = transcriber.transcribe(audio_file_path, config=config)
                if transcript.status == aai.TranscriptStatus.error:
                    raise Exception(
                        f"AssemblyAI transcription error: {transcript.error}"
                    )
                return transcript.text
            elif model_name.startswith("openai/"):
                if use_url:
                    response = requests.get(sample["row"]["audio"][0]["src"])
                    audio_data = BytesIO(response.content)
                    response = openai.Audio.transcribe(
                        model=model_name.split("/")[1],
                        file=audio_data,
                        response_format="text",
                        language="en",
                        temperature=0.0,
                    )
                else:
                    with open(audio_file_path, "rb") as audio_file:
                        response = openai.Audio.transcribe(
                            model=model_name.split("/")[1],
                            file=audio_file,
                            response_format="text",
                            language="en",
                            temperature=0.0,
                        )
                return response.strip()
            elif model_name.startswith("elevenlabs/"):
                client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY"))
                if use_url:
                    response = requests.get(sample["row"]["audio"][0]["src"])
                    audio_data = BytesIO(response.content)
                    transcription = client.speech_to_text.convert(
                        file=audio_data,
                        model_id=model_name.split("/")[1],
                        language_code="eng",
                        tag_audio_events=True,
                    )
                else:
                    with open(audio_file_path, "rb") as audio_file:
                        transcription = client.speech_to_text.convert(
                            file=audio_file,
                            model_id=model_name.split("/")[1],
                            language_code="eng",
                            tag_audio_events=True,
                        )
                return transcription.text
            elif model_name.startswith("revai/"):
                access_token = os.getenv("REVAI_API_KEY")
                client = apiclient.RevAiAPIClient(access_token)
                if use_url:
                    # Submit job with URL for Rev.ai
                    job = client.submit_job_url(
                        transcriber=model_name.split("/")[1],
                        source_config=CustomerUrlData(sample["row"]["audio"][0]["src"]),
                        metadata="benchmarking_job",
                    )
                else:
                    # Submit job with local file
                    job = client.submit_job_local_file(
                        transcriber=model_name.split("/")[1],
                        filename=audio_file_path,
                        metadata="benchmarking_job",
                    )
                # Poll until the job leaves the in-progress states.
                while True:
                    job_details = client.get_job_details(job.id)
                    if job_details.status.name in ["IN_PROGRESS", "TRANSCRIBING"]:
                        time.sleep(0.1)
                        continue
                    elif job_details.status.name == "FAILED":
                        raise Exception("RevAI transcription failed.")
                    elif job_details.status.name == "TRANSCRIBED":
                        break
                transcript_object = client.get_transcript_object(job.id)
                # Concatenate all elements (words and punctuation) from all monologues.
                transcript_text = []
                for monologue in transcript_object.monologues:
                    for element in monologue.elements:
                        transcript_text.append(element.value)
                return "".join(transcript_text) if transcript_text else ""
            else:
                raise ValueError(
                    "Invalid model prefix, must start with 'assembly/', 'openai/', "
                    "'elevenlabs/', 'revai/', or 'speechmatics/'"
                )
        except Exception as e:
            retries += 1
            if retries > max_retries:
                raise e
            if not use_url:
                # Re-create the temporary WAV in case it was removed or truncated
                # by the failed attempt.
                sf.write(
                    audio_file_path,
                    sample["audio"]["array"],
                    sample["audio"]["sampling_rate"],
                    format="WAV",
                )
            delay = 1
            print(
                f"API Error: {str(e)}. Retrying in {delay}s... "
                f"(Attempt {retries}/{max_retries})"
            )
            time.sleep(delay)


def transcribe_dataset(
    dataset_path,
    dataset,
    split,
    model_name,
    use_url=False,
    max_samples=None,
    max_workers=4,
):
    """Transcribe a dataset concurrently, write a manifest, and report WER/RTFx."""
    if use_url:
        audio_rows = fetch_audio_urls(dataset_path, dataset, split)
        if max_samples:
            audio_rows = itertools.islice(audio_rows, max_samples)
        ds = audio_rows
    else:
        ds = datasets.load_dataset(dataset_path, dataset, split=split, streaming=False)
        ds = data_utils.prepare_data(ds)
        if max_samples:
            ds = ds.take(max_samples)

    results = {
        "references": [],
        "predictions": [],
        "audio_length_s": [],
        "transcription_time_s": [],
    }

    print(f"Transcribing with model: {model_name}")

    def process_sample(sample):
        if use_url:
            reference = sample["row"]["text"].strip() or " "
            audio_duration = sample["row"]["audio_length_s"]
            start = time.time()
            try:
                transcription = transcribe_with_retry(
                    model_name, None, sample, use_url=True
                )
            except Exception as e:
                print(f"Failed to transcribe after retries: {e}")
                return None
        else:
            reference = sample.get("norm_text", "").strip() or " "
            # Write the in-memory audio array to a temporary WAV file for upload.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
                sf.write(
                    tmpfile.name,
                    sample["audio"]["array"],
                    sample["audio"]["sampling_rate"],
                    format="WAV",
                )
                tmp_path = tmpfile.name
            audio_duration = (
                len(sample["audio"]["array"]) / sample["audio"]["sampling_rate"]
            )
            start = time.time()
            try:
                transcription = transcribe_with_retry(
                    model_name, tmp_path, sample, use_url=False
                )
            except Exception as e:
                print(f"Failed to transcribe after retries: {e}")
                return None
            finally:
                # Clean up the temporary WAV file.
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
        transcription_time = time.time() - start
        return reference, transcription, audio_duration, transcription_time

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_sample = {
            executor.submit(process_sample, sample): sample for sample in ds
        }
        for future in tqdm(
            concurrent.futures.as_completed(future_to_sample),
            total=len(future_to_sample),
            desc="Transcribing",
        ):
            result = future.result()
            if result:
                reference, transcription, audio_duration, transcription_time = result
                results["predictions"].append(transcription)
                results["references"].append(reference)
                results["audio_length_s"].append(audio_duration)
                results["transcription_time_s"].append(transcription_time)

    # Normalize text before scoring; empty strings become a single space.
    results["predictions"] = [
        data_utils.normalizer(transcription) or " "
        for transcription in results["predictions"]
    ]
    results["references"] = [
        data_utils.normalizer(reference) or " " for reference in results["references"]
    ]

    manifest_path = data_utils.write_manifest(
        results["references"],
        results["predictions"],
        model_name.replace("/", "-"),
        dataset_path,
        dataset,
        split,
        audio_length=results["audio_length_s"],
        transcription_time=results["transcription_time_s"],
    )
    print("Results saved at path:", manifest_path)

    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(
        references=results["references"], predictions=results["predictions"]
    )
    wer_percent = round(100 * wer, 2)
    # RTFx: total audio duration over total transcription time (higher is faster).
    rtfx = round(
        sum(results["audio_length_s"]) / sum(results["transcription_time_s"]), 2
    )
    print("WER:", wer_percent, "%")
    print("RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Unified Transcription Script with Concurrency"
    )
    parser.add_argument("--dataset_path", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--split", default="test")
    parser.add_argument(
        "--model_name",
        required=True,
        help="Prefix model name with 'assembly/', 'openai/', 'elevenlabs/', "
        "'revai/', or 'speechmatics/'",
    )
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument(
        "--max_workers", type=int, default=300, help="Number of concurrent threads"
    )
    parser.add_argument(
        "--use_url",
        action="store_true",
        help="Use URL-based audio fetching instead of loading audio via `datasets`",
    )
    args = parser.parse_args()

    transcribe_dataset(
        dataset_path=args.dataset_path,
        dataset=args.dataset,
        split=args.split,
        model_name=args.model_name,
        use_url=args.use_url,
        max_samples=args.max_samples,
        max_workers=args.max_workers,
    )
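
# Example invocations (the dataset and model names below are illustrative,
# not prescribed by this script; any Hub dataset/config and any supported
# "<provider>/<model>" pair should work):
#
#   python api/run_eval.py \
#       --dataset_path hf-audio/esb-datasets-test-only-sorted \
#       --dataset ami --split test \
#       --model_name assembly/best --use_url
#
#   python api/run_eval.py \
#       --dataset_path hf-audio/esb-datasets-test-only-sorted \
#       --dataset librispeech --split test.clean \
#       --model_name openai/whisper-1 --max_samples 100 --max_workers 50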