in crates/ratchet-models/src/whisper/task.rs [82:128]
/// Greedy autoregressive decode loop: repeatedly runs the decoder on the
/// audio context plus the token history, mutates/samples logits, and stops
/// when the sampler reports completion or `sample_len` steps have run.
///
/// Returns the full token sequence (initial prompt tokens included).
fn main_loop(
    &self,
    decoder: &mut WhisperDecoder,
    audio_ctx: Tensor,
    callback: &Option<impl Fn(StreamedSegment)>,
) -> Result<Vec<i32>, DecodeError> {
    use ratchet::DType;
    let mut tokens = self.get_initial_tokens();
    let vocab_size = self.tokenizer.vocab_size();
    let device = audio_ctx.device().clone();
    let mut timestamps_seen = 0;

    for _ in 0..self.sample_len {
        // Once past the initial prompt, feed only the newest token
        // (presumably the decoder cache carries the earlier context —
        // note the cache_mut().update() call below; TODO confirm).
        // NOTE(review): panics if `initial_tokens_len` is None — relies on
        // get_initial_tokens() having populated it.
        let step_input = if tokens.len() > self.initial_tokens_len.unwrap() {
            &tokens[tokens.len() - 1..]
        } else {
            tokens.as_slice()
        };
        let step_t = Tensor::from_data(step_input, shape![1, step_input.len()], device.clone());

        // Run one decoder step and materialize F32 logits.
        let logits = decoder
            .schedule([audio_ctx.clone(), step_t])?
            .cast(DType::F32)?
            .resolve()?;
        // Advance the decoder cache by the number of tokens just consumed.
        decoder.cache_mut().update(step_input.len());

        // Bring logits to host memory and restrict them to the tokenizer's
        // vocabulary before applying the configured mutators.
        let mut logits = Self::slice_logits(logits.to(&Device::CPU)?, vocab_size);
        let history_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], Device::CPU);
        for mutator in &self.logit_mutators {
            logits = mutator.apply(logits, &self.tokenizer, Some(&history_t))?;
        }

        // Greedily pick the next token; `completed` signals end-of-sequence.
        let (_, sampled, completed) = GreedySampler::sample(tokens, logits)?;
        if let Some(cb) = callback {
            // Stream newly decoded segments to the caller as they appear.
            self.handle_callback(&self.tokenizer, &sampled, &mut timestamps_seen, cb);
        }
        tokens = sampled;
        if completed {
            break;
        }
    }
    Ok(tokens)
}