tensorrtllm/whisper_utils.py (192 lines of code) (raw):
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/openai/whisper
import os
import base64
from functools import lru_cache
from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import tiktoken
# Anything accepted as a filesystem path: a plain string or a pathlib.Path.
Pathlike = Union[str, Path]
# FFT size handed to torch.stft in log_mel_spectrogram; also the length of the
# Hann window built there.
N_FFT = 400
# Hop length handed to torch.stft in log_mel_spectrogram.
HOP_LENGTH = 160
# Language code -> lowercase language name.
# NOTE: insertion order matters. get_tokenizer() takes the first
# `num_languages` keys of this dict, in order, and assigns the matching
# `<|code|>` special tokens consecutive ids; reordering or inserting entries
# in the middle therefore shifts the ids of every later special token.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}
def get_tokenizer(name: str = "multilingual",
                  num_languages: int = 99,
                  tokenizer_dir: Optional[str] = None):
    """
    Build a ``tiktoken.Encoding`` from a ``<name>.tiktoken`` vocabulary file.

    Each non-blank line of the vocabulary file holds a base64-encoded token
    followed by its integer rank. Special tokens are appended after the
    regular vocabulary and receive consecutive ids starting at the number of
    regular tokens.

    Parameters
    ----------
    name: str
        Base name of the vocabulary file (``<name>.tiktoken``)
    num_languages: int
        How many leading entries of ``LANGUAGES`` get a ``<|code|>`` special token
    tokenizer_dir: Optional[str]
        Directory containing the vocabulary file; when None, the ``assets``
        directory next to this module is used

    Returns
    -------
    tiktoken.Encoding
        Encoding whose ``explicit_n_vocab`` counts regular plus special tokens
    """
    if tokenizer_dir is None:
        vocab_path = os.path.join(os.path.dirname(__file__),
                                  f"assets/{name}.tiktoken")
    else:
        vocab_path = os.path.join(tokenizer_dir, f"{name}.tiktoken")

    # Read inside a context manager so the file handle is always released,
    # and skip whitespace-only lines: a line read from a file is never falsy
    # (a blank line is "\n"), so it would otherwise break the unpacking below.
    ranks = {}
    with open(vocab_path, encoding="utf-8") as vocab_file:
        for line in vocab_file:
            if not line.strip():
                continue
            token, rank = line.split()
            ranks[base64.b64decode(token)] = int(rank)

    # The order of this list fixes the id of every special token.
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES)[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        # Timestamp tokens <|0.00|> ... <|30.00|> in steps of 0.02.
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    n_regular = len(ranks)
    special_tokens = {
        token: n_regular + offset
        for offset, token in enumerate(specials)
    }
    n_vocab = n_regular + len(specials)

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
@lru_cache(maxsize=None)
def mel_filters(device,
                n_mels: int,
                mel_filters_dir: str = None) -> torch.Tensor:
    """
    Return the Mel filterbank matrix used to project an STFT power spectrum
    onto ``n_mels`` Mel bands, as a tensor on ``device``.

    The matrix is read from ``mel_filters.npz`` (array ``mel_<n_mels>``),
    looked up in ``mel_filters_dir`` or, when that is None, in the ``assets``
    directory next to this module. Results are memoised per argument tuple.

    Shipping the matrix as a file avoids a librosa dependency; an archive can
    be produced with e.g.:
        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    base_dir = mel_filters_dir
    if base_dir is None:
        base_dir = os.path.join(os.path.dirname(__file__), "assets")
    archive_path = os.path.join(base_dir, "mel_filters.npz")

    array_name = f"mel_{n_mels}"
    with np.load(archive_path) as archive:
        # Indexing materialises the array, so it stays valid after the
        # archive is closed.
        bank = archive[array_name]
    return torch.from_numpy(bank).to(device)
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    return_duration: bool = False,
    mel_filters_dir: str = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The audio waveform in 16 kHz. Although the annotation is wider, this
        implementation only accepts a torch.Tensor and asserts otherwise.
    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported
    padding: int
        Number of zero samples to pad to the right
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT
    return_duration: bool
        If True, also return the duration of the un-padded input in seconds
    mel_filters_dir: str
        Directory containing ``mel_filters.npz``; forwarded to ``mel_filters``

    Returns
    -------
    torch.Tensor, shape = (80 or 128, n_frames)
        A Tensor that contains the Mel spectrogram. When ``return_duration``
        is True, a ``(spectrogram, duration)`` tuple is returned instead.
    """
    assert torch.is_tensor(audio), f"Unsupported audio type: {type(audio)}"

    # Duration of the caller's audio, measured before any zero padding is
    # appended. The docstring contract is a 16 kHz waveform.
    sample_rate = 16000
    duration = audio.shape[-1] / sample_rate

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        # Append `padding` zero samples on the right of the last dimension.
        audio = F.pad(audio, (0, padding))

    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio,
                      N_FFT,
                      HOP_LENGTH,
                      window=window,
                      return_complex=True)
    # Power spectrum; the final STFT frame is dropped.
    magnitudes = stft[..., :-1].abs()**2

    filters = mel_filters(audio.device, n_mels, mel_filters_dir)
    mel_spec = filters @ magnitudes

    # Clamp before the log to avoid log10(0), then floor everything at 8
    # log10 units below the global maximum, and finally shift and scale.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0

    if return_duration:
        return log_spec, duration
    return log_spec