STT/whisper_stt_handler.py
def warmup(self):
    logger.info(f"Warming up {self.__class__.__name__}")

    # one warmup step suffices for compile_mode == "default"; two are needed
    # when not compiling or when the compile mode captures CUDA graphs
    n_steps = 1 if self.compile_mode == "default" else 2
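    # a 30 s dummy clip: Whisper's feature extractor emits 3000 log-mel
    # frames (100 per second) per 30-second window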
    dummy_input = torch.randn(
        (1, self.model.config.num_mel_bins, 3000),
        dtype=self.torch_dtype,
        device=self.device,
    )
    if self.compile_mode not in (None, "default"):
        # generating more tokens than were generated during warmup would
        # trigger a new CUDA graphs capture, so warm up with at least as
        # many tokens as any subsequent generation will request: pinning
        # min_new_tokens to max_new_tokens guarantees the full-length
        # graph is captured
        warmup_gen_kwargs = {
            **self.gen_kwargs,
            # yes, assign max_new_tokens to min_new_tokens; the overrides
            # must come after the unpacking, or a min_new_tokens entry in
            # gen_kwargs could undo them
            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
        }
    else:
        warmup_gen_kwargs = self.gen_kwargs
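    # time the warmup on GPU with CUDA events; elapsed_time() returns
    # milliseconds, hence the 1e-3 factor when logging below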
    if self.device == "cuda":
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()

    for _ in range(n_steps):
        _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
    if self.device == "cuda":
        end_event.record()
        torch.cuda.synchronize()
        logger.info(
            f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
        )
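
For context, here is a minimal sketch of the handler state this method relies on. The constructor signature, defaults, and model name below are assumptions for illustration, not the repository's actual code; only the attributes warmup() reads (model, device, torch_dtype, compile_mode, gen_kwargs) are taken from the excerpt above.

import logging

import torch
from transformers import AutoModelForSpeechSeq2Seq

logger = logging.getLogger(__name__)


class WhisperSTTHandler:
    # hypothetical constructor: it only sets up the attributes warmup() reads
    def __init__(
        self,
        model_name="openai/whisper-large-v3",  # assumed default
        device="cuda",
        torch_dtype=torch.float16,
        compile_mode="reduce-overhead",  # the mode that captures CUDA graphs
        gen_kwargs=None,
    ):
        self.device = device
        self.torch_dtype = torch_dtype
        self.compile_mode = compile_mode
        # warmup() reads gen_kwargs["max_new_tokens"], so it must be set
        self.gen_kwargs = gen_kwargs or {"max_new_tokens": 128}
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name, torch_dtype=torch_dtype
        ).to(device)
        if self.compile_mode:
            # a static KV cache gives torch.compile fixed shapes to capture
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )

With warmup() attached as a method of such a class, calling WhisperSTTHandler().warmup() runs two full-length generations, so the CUDA graphs captured under "reduce-overhead" mode cover any later call requesting up to max_new_tokens tokens.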