docker_images/k2/app/common.py
import functools
import json
from typing import List, Optional, Union

import k2
import kaldifeat
import sentencepiece as spm
import torch
from huggingface_hub import HfApi, hf_hub_download
from sherpa import RnntConformerModel

from .decode import (
    run_model_and_do_greedy_search,
    run_model_and_do_modified_beam_search,
)

def get_hfconfig(model_id, config_name="hf_demo"):
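    """Fetch config.json from the Hugging Face repo `model_id` and return the
    section named `config_name` (default: "hf_demo").

    If the config exists but has no such section, a ValueError is raised.
    If `config_name` is None or the parsed config is empty, the model info
    object (with the parsed config attached) is returned instead.
    """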
    info = HfApi().model_info(repo_id=model_id)
    config_file = hf_hub_download(model_id, filename="config.json")
    with open(config_file) as config:
        info.config = json.load(config)

    if info.config and config_name is not None:
        if config_name in info.config:
            return info.config[config_name]
        else:
            raise ValueError(f"Config section {config_name} not found")
    else:
        return info

def model_from_hfconfig(hf_repo, hf_config):
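    """Build an OfflineAsr instance from a Hugging Face repo and its config section.

    Downloads the torchscript checkpoint plus either a BPE model or a
    tokens.txt file (exactly one of the two must be present in the config)
    and reads the decoding options, defaulting to greedy_search at 16 kHz
    with 4 active paths.
    """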
    nn_model_filename = hf_hub_download(hf_repo, hf_config["nn_model_filename"])
    token_filename = (
        hf_hub_download(hf_repo, hf_config["token_filename"])
        if "token_filename" in hf_config
        else None
    )
    bpe_model_filename = (
        hf_hub_download(hf_repo, hf_config["bpe_model_filename"])
        if "bpe_model_filename" in hf_config
        else None
    )
    decoding_method = hf_config.get("decoding_method", "greedy_search")
    sample_rate = hf_config.get("sample_rate", 16000)
    num_active_paths = hf_config.get("num_active_paths", 4)

    assert decoding_method in ("greedy_search", "modified_beam_search"), decoding_method
    if decoding_method == "modified_beam_search":
        assert num_active_paths >= 1, num_active_paths
    assert bpe_model_filename is not None or token_filename is not None
    if bpe_model_filename:
        assert token_filename is None
    if token_filename:
        assert bpe_model_filename is None

    return OfflineAsr(
        nn_model_filename,
        bpe_model_filename,
        token_filename,
        decoding_method,
        num_active_paths,
        sample_rate,
    )

def transcribe_batch_from_tensor(model, batch):
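    """Decode a single 1-D waveform tensor with `model` and return its transcript."""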
    return model.decode_waves([batch])[0]

class OfflineAsr(object):
    def __init__(
        self,
        nn_model_filename: str,
        bpe_model_filename: Optional[str],
        token_filename: Optional[str],
        decoding_method: str,
        num_active_paths: int,
        sample_rate: int = 16000,
        device: Union[str, torch.device] = "cpu",
    ):
        """
        Args:
          nn_model_filename:
            Path to the torch script model.
          bpe_model_filename:
            Path to the BPE model. If it is None, you have to provide
            `token_filename`.
          token_filename:
            Path to tokens.txt. If it is None, you have to provide
            `bpe_model_filename`.
          decoding_method:
            The decoding method to use. Currently, only greedy_search and
            modified_beam_search are implemented.
          num_active_paths:
            Used only when decoding_method is modified_beam_search.
            It specifies the number of active paths for each utterance. Due to
            merging paths with identical token sequences, the actual number
            may be less than "num_active_paths".
          sample_rate:
            Expected sample rate of the feature extractor.
          device:
            The device to use for computation.
        """
        self.model = RnntConformerModel(
            filename=nn_model_filename,
            device=device,
            optimize_for_inference=False,
        )

        if bpe_model_filename:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(bpe_model_filename)
        else:
            self.token_table = k2.SymbolTable.from_file(token_filename)

        self.sample_rate = sample_rate
        self.feature_extractor = self._build_feature_extractor(
            sample_rate=sample_rate,
            device=device,
        )

        assert decoding_method in (
            "greedy_search",
            "modified_beam_search",
        ), decoding_method
        if decoding_method == "greedy_search":
            nn_and_decoding_func = run_model_and_do_greedy_search
        elif decoding_method == "modified_beam_search":
            nn_and_decoding_func = functools.partial(
                run_model_and_do_modified_beam_search,
                num_active_paths=num_active_paths,
            )
        else:
            raise ValueError(
                f"Unsupported decoding_method: {decoding_method} "
                "Please use greedy_search or modified_beam_search"
            )

        self.nn_and_decoding_func = nn_and_decoding_func
        self.device = device

    def _build_feature_extractor(
        self,
        sample_rate: int = 16000,
        device: Union[str, torch.device] = "cpu",
    ) -> kaldifeat.OfflineFeature:
        """Build a fbank feature extractor for extracting features.

        Args:
          sample_rate:
            Expected sample rate of the feature extractor.
          device:
            The device to use for computation.
        Returns:
          Return a fbank feature extractor.
        """
        opts = kaldifeat.FbankOptions()
        opts.device = device
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        opts.frame_opts.samp_freq = sample_rate
        opts.mel_opts.num_bins = 80

        fbank = kaldifeat.Fbank(opts)
        return fbank

    def decode_waves(self, waves: List[torch.Tensor]) -> List[str]:
        """
        Args:
          waves:
            A list of 1-D torch.float32 tensors containing audio samples.
            waves[i] contains audio samples for the i-th utterance.

            Note:
              Whether it should be in the range [-32768, 32767] or be
              normalized to [-1, 1] depends on which range you used for your
              training data. For instance, if your training data used
              [-32768, 32767], then the given waves have to contain samples
              in this range.

              All models trained in icefall use the normalized range [-1, 1].
        Returns:
          Return a list of decoded results. `ans[i]` contains the decoded
          results for `waves[i]`.
        """
        waves = [w.to(self.device) for w in waves]
        features = self.feature_extractor(waves)

        tokens = self.nn_and_decoding_func(self.model, features)

        if hasattr(self, "sp"):
            results = self.sp.decode(tokens)
        else:
            results = [[self.token_table[i] for i in hyp] for hyp in tokens]
            results = ["".join(r) for r in results]

        return results
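

# Minimal usage sketch (illustrative only, not part of the docker app): the
# repo id below is a placeholder and torchaudio is assumed to be installed
# for loading audio. It shows how the helpers above fit together: read the
# "hf_demo" section of the repo's config.json, build an OfflineAsr from it,
# and decode one waveform.
if __name__ == "__main__":
    import torchaudio

    repo_id = "your-username/your-rnnt-conformer-repo"  # placeholder repo id
    hf_config = get_hfconfig(repo_id, config_name="hf_demo")
    model = model_from_hfconfig(repo_id, hf_config)

    # torchaudio.load returns float32 samples normalized to [-1, 1],
    # which is what models trained in icefall expect.
    samples, sample_rate = torchaudio.load("test.wav")
    assert sample_rate == model.sample_rate, (sample_rate, model.sample_rate)

    text = transcribe_batch_from_tensor(model, samples[0])
    print(text)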