STT/moonshine_handler.py (51 lines of code) (raw):

import os

# Must be set before moonshine (keras-based) is imported.
os.environ['KERAS_BACKEND'] = 'torch'

import logging
from time import perf_counter

import moonshine
import torch
from rich.console import Console

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)
console = Console()


class MoonshineSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Moonshine model.
    """

    def setup(
        self,
        model_name="moonshine/base",
        torch_dtype="float16",
        gen_kwargs=None,
    ):
        """Load the Moonshine model and tokenizer, then run a warmup pass.

        Args:
            model_name: Moonshine checkpoint identifier passed to
                ``moonshine.load_model``.
            torch_dtype: Name of a ``torch`` dtype attribute (e.g. "float16").
            gen_kwargs: Optional extra generation kwargs; defaults to an empty
                dict. (Was a mutable default ``{}``, which is shared across
                calls — replaced with the ``None`` sentinel idiom.)
        """
        self.torch_dtype = getattr(torch, torch_dtype)
        self.gen_kwargs = {} if gen_kwargs is None else gen_kwargs
        self.tokenizer = moonshine.load_tokenizer()
        self.model = moonshine.load_model(model_name)
        self.warmup()

    def warmup(self):
        """Run a few dummy generations so the first real request is not cold."""
        logger.info(f"Warming up {self.__class__.__name__}")
        n_steps = 2
        # One second of random audio at 16 kHz, shaped (batch, samples).
        # NOTE(review): assumes the model expects 16 kHz mono input — confirm
        # against the moonshine documentation.
        dummy_input = torch.randn(
            (1, 16000),
            dtype=self.torch_dtype,
        )

        if torch.cuda.is_available():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input)

        if torch.cuda.is_available():
            end_event.record()
            torch.cuda.synchronize()
            # cuda.Event.elapsed_time() reports milliseconds; * 1e-3 -> seconds.
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, spoken_prompt):
        """Transcribe ``spoken_prompt`` and yield ``(text, language)``.

        Args:
            spoken_prompt: 1-D audio tensor/array; a batch dimension is added
                before generation.

        Yields:
            A ``(transcribed_text, "en")`` tuple for the downstream pipeline.
        """
        # Fixed typo in the debug message ("infering" -> "inferring").
        logger.debug("inferring moonshine...")

        # Shared pipeline timing anchor read elsewhere in the project.
        global pipeline_start
        pipeline_start = perf_counter()

        # Add a leading batch dimension for generate().
        pred_ids = self.model.generate(spoken_prompt[None, :])
        pred_text = self.tokenizer.decode_batch(pred_ids)[0]

        # Was "finished whisper inference" — copy-paste from the Whisper
        # handler; corrected so logs name the right backend.
        logger.debug("finished moonshine inference")

        console.print(f"[yellow]USER: {pred_text}")

        yield (pred_text, "en")