docker_images/speechbrain/app/pipelines/audio_to_audio.py
from typing import List, Tuple

import numpy as np
import torch
from speechbrain.inference import (
    SepformerSeparation,
    SpectralMaskEnhancement,
    WaveformEnhancement,
)

from app.common import ModelType, get_type
from app.pipelines import Pipeline

class AudioToAudioPipeline(Pipeline):
    def __init__(self, model_id: str):
        model_type = get_type(model_id)
        if model_type == ModelType.SEPFORMERSEPARATION:
            self.model = SepformerSeparation.from_hparams(source=model_id)
            self.type = "audio-source-separation"
        elif model_type == ModelType.SPECTRALMASKENHANCEMENT:
            self.model = SpectralMaskEnhancement.from_hparams(source=model_id)
            self.type = "speech-enhancement"
        elif model_type == ModelType.WAVEFORMENHANCEMENT:
            self.model = WaveformEnhancement.from_hparams(source=model_id)
            self.type = "speech-enhancement"
        else:
            raise ValueError(f"{model_type.value} is invalid for audio-to-audio")
        self.sampling_rate = self.model.hparams.sample_rate

    def __call__(self, inputs: np.ndarray) -> Tuple[np.ndarray, int, List[str]]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of audio received. By default sampled at `self.sampling_rate`.
                The shape of this array is `T`, where `T` is the time axis.
        Return:
            A :obj:`tuple` containing:
              - :obj:`np.ndarray`:
                    The returned array must have shape `C'` x `T'`.
              - an :obj:`int`: the sampling rate as an int in Hz.
              - a :obj:`List[str]`: the annotation for each output channel.
                    This can be the name of the instruments for audio source separation
                    or some annotation for speech enhancement. The length must be `C'`.
        See the minimal usage sketch at the bottom of this module.
        """
        if self.type == "speech-enhancement":
            return self.enhance(inputs)
        # "audio-source-separation" is the only other type set in __init__.
        return self.separate(inputs)

    def separate(self, inputs):
        mix = torch.from_numpy(inputs)
        # separate_batch expects a batch dimension: (T,) -> (1, T).
        est_sources = self.model.separate_batch(mix.unsqueeze(0))
        # Output is (batch, T', C'); drop the batch dimension and move to C' x T'.
        est_sources = est_sources[0]
        est_sources = est_sources.transpose(1, 0)
        # Peak-normalize each separated source for loudness.
        est_sources = est_sources / est_sources.abs().max(dim=1, keepdim=True).values
        n = est_sources.shape[0]
        labels = [f"label_{i}" for i in range(n)]
        return est_sources.numpy(), int(self.sampling_rate), labels

    def enhance(self, inputs: np.ndarray):
        mix = torch.from_numpy(inputs)
        # enhance_batch expects a batch dimension: (T,) -> (1, T).
        enhanced = self.model.enhance_batch(mix.unsqueeze(0))
        # The single enhanced waveform comes back as C' x T' with C' == 1.
        labels = ["speech_enhanced"]
        return enhanced.numpy(), int(self.sampling_rate), labels
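

# Minimal usage sketch (not part of the original pipeline): shows how the
# pipeline is called end to end and what the returned tuple looks like. The
# model id below is only an example of a SepformerSeparation checkpoint; any
# model id that app.common.get_type maps to one of the supported ModelType
# values should behave the same way.
if __name__ == "__main__":
    pipeline = AudioToAudioPipeline("speechbrain/sepformer-wsj02mix")
    # One second of random noise at the model's sampling rate, shape (T,).
    waveform = np.random.randn(int(pipeline.sampling_rate)).astype(np.float32)
    with torch.no_grad():
        sources, rate, labels = pipeline(waveform)
    # sources has shape C' x T', rate is in Hz, and labels has length C'.
    print(sources.shape, rate, labels)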