docker_images/speechbrain/app/pipelines/text2text_generation.py (17 lines of code) (raw):
from typing import Dict, List
from app.common import ModelType, get_type
from app.pipelines import Pipeline
from speechbrain.inference import GraphemeToPhoneme
# Per-model-type post-processors applied to the raw model output.
# Grapheme-to-phoneme output is joined into a single hyphen-separated string
# (assumes the model yields an iterable of str — TODO confirm against SpeechBrain).
POSTPROCESSING = {ModelType.GRAPHEMETOPHONEME: "-".join}
class TextToTextPipeline(Pipeline):
    """Text-to-text pipeline backed by a SpeechBrain model.

    Only ``ModelType.GRAPHEMETOPHONEME`` models are accepted; any other model
    type is rejected with a ``ValueError`` when the pipeline is constructed.
    """

    def __init__(self, model_id: str):
        """Load the SpeechBrain model identified by ``model_id``.

        Args:
            model_id (:obj:`str`):
                Identifier used both to look up the model type via
                ``get_type`` and as the ``source`` passed to
                ``GraphemeToPhoneme.from_hparams``.

        Raises:
            ValueError: If the model type is not
                ``ModelType.GRAPHEMETOPHONEME``.
        """
        model_type = get_type(model_id)
        if model_type == ModelType.GRAPHEMETOPHONEME:
            self.model = GraphemeToPhoneme.from_hparams(source=model_id)
        else:
            raise ValueError(f"{model_type.value} is invalid for text-to-text")
        # Model types without an entry in POSTPROCESSING fall back to an
        # identity function, i.e. the raw model output is returned unchanged.
        self.post_process = POSTPROCESSING.get(model_type, lambda output: output)

    def __call__(self, inputs: str) -> List[Dict[str, str]]:
        """Run the model on ``inputs`` and post-process the result.

        Args:
            inputs (:obj:`str`):
                The input text.

        Return:
            A :obj:`list` containing a single dict
            ``{"generated_text": <post-processed model output>}``.
        """
        output = self.model(inputs)
        output = self.post_process(output)
        return [{"generated_text": output}]