whisperx-model-fetch/download_whisperx_models.py (100 lines of code) (raw):

# This script is used in the amigo whisperx role - see
# https://github.com/guardian/amigo/pull/1607 - it enables running whisperx
# offline, by pre-downloading all the required models to the AMI. It exists in
# this repo so that we can make changes as new models are released without
# having to modify amigo itself.

import torchaudio
from pyannote.audio import Pipeline
import sys
import huggingface_hub
import typer

# ASR Models
# Should be kept in sync with https://github.com/m-bain/whisperX/blob/main/whisperx/asr.py
# NOTE(review): upstream appears to keep these alignment tables in
# whisperx/alignment.py rather than asr.py - confirm and update the link.

# Language code -> name of a torchaudio pipeline bundle used for alignment.
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

# Language code -> Hugging Face repo id of the alignment model.
DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi",
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    "ca": "softcatala/wav2vec2-large-xlsr-catala",
    "ml": "gvs/wav2vec2-large-xlsr-malayalam",
    "no": "NbAiLab/nb-wav2vec2-1b-bokmaal",
    "nn": "NbAiLab/nb-wav2vec2-300m-nynorsk",
}


def download_torch_align_models() -> None:
    """Fetch every torchaudio alignment bundle in DEFAULT_ALIGN_MODELS_TORCH."""
    for lang, model_name in DEFAULT_ALIGN_MODELS_TORCH.items():
        print(f"Downloading {model_name} for {lang}")
        # Look the bundle up by name on the pipelines module; calling
        # get_model() on it is what triggers the download.
        bundle = getattr(torchaudio.pipelines, model_name)
        bundle.get_model()
        print(f"Downloaded {model_name} for {lang}")


def download_huggingface_align_models() -> None:
    """Fetch every Hugging Face alignment repo in DEFAULT_ALIGN_MODELS_HF."""
    for lang, model_name in DEFAULT_ALIGN_MODELS_HF.items():
        print(f"Downloading {model_name} for {lang}")
        huggingface_hub.snapshot_download(model_name)
        print(f"Downloaded {model_name} for {lang}")


# Diarization - see https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
def download_diarization_models(auth_token: str) -> None:
    """Fetch the pyannote speaker-diarization pipeline.

    Args:
        auth_token: Hugging Face token, forwarded to
            ``Pipeline.from_pretrained`` as ``use_auth_token``.
    """
    pyannote_model = "pyannote/speaker-diarization-3.1"
    print(f"Downloading diarization models {pyannote_model}")
    Pipeline.from_pretrained(pyannote_model, use_auth_token=auth_token)


# faster-whisper models
################
# Note - this section below is copied from
# https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py
# and then heavily simplified to only include the models we need
###############

# Model size -> Hugging Face repo id of the CTranslate2 conversion.
WHISPER_MODELS = {
    "tiny": "Systran/faster-whisper-tiny",
    "small": "Systran/faster-whisper-small",
    "medium": "Systran/faster-whisper-medium",
    "large": "Systran/faster-whisper-large-v3",
}


def download_model(
    model: str,
):
    """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

    Args:
      model: Size of the model to download from https://huggingface.co/Systran
        (see https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/utils.py#L12
        for full list) - here limited to tiny, small, medium, large.

    Returns:
      The path to the downloaded model.

    Raises:
      ValueError: If ``model`` is not one of the keys of WHISPER_MODELS.
    """
    # Fail fast with a clear message rather than handing ``None`` to
    # snapshot_download as the repo id.
    try:
        repo_id = WHISPER_MODELS[model]
    except KeyError:
        supported = ", ".join(WHISPER_MODELS)
        raise ValueError(
            f"Unknown whisper model '{model}', expected one of: {supported}"
        ) from None

    print(f"Downloading whisper model {model}")

    # Only these files are requested from the repo.
    allow_patterns = [
        "config.json",
        "preprocessor_config.json",
        "model.bin",
        "tokenizer.json",
        "vocabulary.*",
    ]

    return huggingface_hub.snapshot_download(repo_id, allow_patterns=allow_patterns)


def download_all_whisper_models() -> None:
    """Download every model size listed in WHISPER_MODELS."""
    for model_name in WHISPER_MODELS:
        download_model(model_name)


app = typer.Typer()


@app.command()
def main(
    whisper_models: bool = typer.Option(False, help="Download whisper models"),
    diarization_models: bool = typer.Option(False, help="Download diarization models"),
    torch_align_models: bool = typer.Option(False, help="Download torch align models"),
    huggingface_align_models: bool = typer.Option(False, help="Download huggingface align models"),
    huggingface_token: str = typer.Option("", help="Huggingface authentication token"),
):
    """Download the model groups selected on the command line.

    Groups are processed in a fixed order: whisper, diarization, torch align,
    Hugging Face align. The diarization group needs ``huggingface_token``; if
    it is missing the process exits with status 1 at that point (after any
    whisper download already requested has run).
    """
    if whisper_models:
        download_all_whisper_models()

    if diarization_models:
        if not huggingface_token:
            print("Please provide a Huggingface authentication token (--huggingface-token <token>)")
            sys.exit(1)
        download_diarization_models(huggingface_token)

    if torch_align_models:
        download_torch_align_models()

    if huggingface_align_models:
        download_huggingface_align_models()


if __name__ == "__main__":
    typer.run(main)