torchbenchmark/models/tts_angular/__init__.py (25 lines of code) (raw):

from ...util.model import BenchmarkModel from torchbenchmark.tasks import SPEECH import torch from typing import Tuple from .angular_tts_main import TTSModel, SYNTHETIC_DATA class Model(BenchmarkModel): task = SPEECH.SYNTHESIS # Original train batch size: 64 # Source: https://github.com/mozilla/TTS/blob/master/TTS/speaker_encoder/config.json#L38 DEFAULT_TRAIN_BSIZE = 64 DEFAULT_EVAL_BSIZE = 64 def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]): super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args) self.model = TTSModel(device=self.device, batch_size=self.batch_size) self.model.model.to(self.device) def get_module(self): return self.model.model, [SYNTHETIC_DATA[0], ] def set_module(self, new_model): self.model.model = new_model def set_train(self): self.model.model.train() def train(self, niter=1): # the training process is not patched to use scripted models self.model.train(niter) def eval(self, niter=1) -> Tuple[torch.Tensor]: for _ in range(niter): out = self.model.eval() return (out, )