TTS/parler_handler.py
def warmup(self):
    logger.info(f"Warming up {self.__class__.__name__}")

    if self.device == "cuda":
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    # One warmup step suffices for compile_mode="default"; eager mode and compile
    # modes that capture CUDA graphs (e.g. "reduce-overhead") need two.
    n_steps = 1 if self.compile_mode == "default" else 2

    if self.device == "cuda":
        torch.cuda.synchronize()
        start_event.record()

    if self.compile_mode:
        # Warm every power-of-two prompt-length bucket, longest first, so the
        # compiled model has already seen each padded shape it can receive.
        pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
        for pad_length in pad_lengths[::-1]:
            model_kwargs = self.prepare_model_inputs(
                "dummy prompt", max_length_prompt=pad_length, pad=True
            )
            for _ in range(n_steps):
                _ = self.model.generate(**model_kwargs)
            logger.info(f"Warmed up length {pad_length} tokens!")
    else:
        model_kwargs = self.prepare_model_inputs("dummy prompt")
        for _ in range(n_steps):
            _ = self.model.generate(**model_kwargs)

    if self.device == "cuda":
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds; convert to seconds for the log.
        logger.info(
            f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
        )