docker_images/speechbrain/app/pipelines/audio_to_audio.py

from typing import List, Tuple

import numpy as np
import torch

from app.common import ModelType, get_type
from app.pipelines import Pipeline
from speechbrain.inference import (
    SepformerSeparation,
    SpectralMaskEnhancement,
    WaveformEnhancement,
)


class AudioToAudioPipeline(Pipeline):
    def __init__(self, model_id: str):
        model_type = get_type(model_id)
        if model_type == ModelType.SEPFORMERSEPARATION:
            self.model = SepformerSeparation.from_hparams(source=model_id)
            self.type = "audio-source-separation"
        elif model_type == ModelType.SPECTRALMASKENHANCEMENT:
            self.model = SpectralMaskEnhancement.from_hparams(source=model_id)
            self.type = "speech-enhancement"
        elif model_type == ModelType.WAVEFORMENHANCEMENT:
            self.model = WaveformEnhancement.from_hparams(source=model_id)
            self.type = "speech-enhancement"
        else:
            raise ValueError(f"{model_type.value} is invalid for audio-to-audio")
        self.sampling_rate = self.model.hparams.sample_rate

    def __call__(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of the received audio, sampled at
                `self.sampling_rate` by default. The shape of this array is
                `T`, where `T` is the time axis.
        Return:
            A :obj:`tuple` containing:
              - :obj:`np.ndarray`: the output waveforms, of shape `C'` x `T'`.
              - an :obj:`int`: the sampling rate in Hz.
              - a :obj:`List[str]`: the annotation for each output channel.
                This can be the name of the instruments for audio source
                separation or some annotation for speech enhancement. The
                length must be `C'`.
        """
        if self.type == "speech-enhancement":
            return self.enhance(inputs)
        return self.separate(inputs)

    def separate(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        mix = torch.from_numpy(inputs)
        # separate_batch expects a batch dimension: (1, T) in, (1, T', C') out
        est_sources = self.model.separate_batch(mix.unsqueeze(0))
        est_sources = est_sources[0]
        # transpose to C' x T'
        est_sources = est_sources.transpose(1, 0)
        # normalize for loudness
        est_sources = est_sources / est_sources.abs().max(dim=1, keepdim=True).values
        n = est_sources.shape[0]
        labels = [f"label_{i}" for i in range(n)]
        return est_sources.numpy(), int(self.sampling_rate), labels

    def enhance(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        mix = torch.from_numpy(inputs)
        # enhance_batch returns a single enhanced channel: C' x T' with C' == 1
        enhanced = self.model.enhance_batch(mix.unsqueeze(0))
        labels = ["speech_enhanced"]
        return enhanced.numpy(), int(self.sampling_rate), labels
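

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): it exercises the
# __call__ contract documented above, i.e. a 1-D waveform of shape (T,) going
# in and a (C', T') array, a sampling rate, and C' labels coming out. The
# model id "speechbrain/sepformer-wsj02mix" and the random test waveform are
# assumptions made for this example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    pipeline = AudioToAudioPipeline("speechbrain/sepformer-wsj02mix")
    # One second of noise at the model's sampling rate, shaped (T,).
    waveform = rng.standard_normal(int(pipeline.sampling_rate)).astype(np.float32)

    sources, rate, labels = pipeline(waveform)
    # Expected: (C', T') array, sampling rate in Hz, one label per channel.
    print(sources.shape, rate, labels)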