torchbenchmark/models/tts_angular/__init__.py (25 lines of code) (raw):
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import SPEECH
import torch
from typing import Tuple
from .angular_tts_main import TTSModel, SYNTHETIC_DATA
class Model(BenchmarkModel):
    """Benchmark wrapper that drives a ``TTSModel`` for the speech-synthesis task.

    The wrapped ``TTSModel`` instance is stored in ``self.model``; the module
    it holds is reachable as ``self.model.model``.
    """

    task = SPEECH.SYNTHESIS

    # Original train batch size: 64
    # Source: https://github.com/mozilla/TTS/blob/master/TTS/speaker_encoder/config.json#L38
    DEFAULT_TRAIN_BSIZE = 64
    DEFAULT_EVAL_BSIZE = 64

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=None):
        """Build the wrapped ``TTSModel`` and move its inner module to the device.

        All arguments are forwarded unchanged to ``BenchmarkModel.__init__``.
        ``extra_args`` defaults to a fresh empty list when omitted.
        """
        # Use a None sentinel instead of a shared mutable ``[]`` default so
        # that no list object is reused across separate constructions.
        if extra_args is None:
            extra_args = []
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        # NOTE(review): self.device / self.batch_size are presumably populated
        # by BenchmarkModel.__init__ -- confirm against the base class.
        self.model = TTSModel(device=self.device, batch_size=self.batch_size)
        self.model.model.to(self.device)

    def get_module(self):
        """Return the inner module and a one-element list of example inputs.

        The example input is the first entry of ``SYNTHETIC_DATA``.
        """
        return self.model.model, [SYNTHETIC_DATA[0], ]

    def set_module(self, new_model):
        """Replace the inner module held by the wrapped ``TTSModel``."""
        self.model.model = new_model

    def set_train(self):
        """Call ``train()`` on the inner module."""
        self.model.model.train()

    def train(self, niter=1):
        """Delegate ``niter`` training iterations to the wrapped ``TTSModel``."""
        # the training process is not patched to use scripted models
        self.model.train(niter)

    def eval(self, niter=1) -> Tuple[torch.Tensor]:
        """Run ``TTSModel.eval()`` ``niter`` times and return the last result.

        Returns:
            A one-element tuple holding the output of the final iteration.

        Raises:
            ValueError: if ``niter`` is less than 1, because no iteration
                would run and there would be no output to return.
        """
        # Without this guard a non-positive niter would skip the loop and hit
        # an UnboundLocalError on ``out`` at the return statement.
        if niter < 1:
            raise ValueError(f"niter must be >= 1, got {niter}")
        for _ in range(niter):
            out = self.model.eval()
        return (out, )