fn main_loop() — excerpted from crates/ratchet-models/src/whisper/task.rs, lines 82–128


    /// Autoregressive greedy decoding loop for one Whisper task.
    ///
    /// Starting from `self.get_initial_tokens()`, runs up to `self.sample_len`
    /// decode steps: each step schedules the decoder on the audio context plus
    /// the current token input, applies the configured logit mutators, and
    /// greedily samples the next token(s). Stops early when the sampler
    /// reports completion.
    ///
    /// * `decoder` - decoder whose KV cache is advanced by `input.len()` each step.
    /// * `audio_ctx` - encoder output tensor; also supplies the target device.
    ///   (NOTE(review): expected shape not established here — see encoder.)
    /// * `callback` - optional observer invoked with newly decoded tokens via
    ///   `handle_callback` (streaming segments to the caller).
    ///
    /// Returns the full token sequence (initial tokens included), or the first
    /// `DecodeError` raised by scheduling, mutation, or sampling.
    ///
    /// # Panics
    /// Panics if `self.initial_tokens_len` is `None` when decoding starts.
    fn main_loop(
        &self,
        decoder: &mut WhisperDecoder,
        audio_ctx: Tensor,
        callback: &Option<impl Fn(StreamedSegment)>,
    ) -> Result<Vec<i32>, DecodeError> {
        use ratchet::DType;

        let mut tokens = self.get_initial_tokens();
        let sliced_vocab_size = self.tokenizer.vocab_size();
        let device = audio_ctx.device().clone();
        let mut timestamps_seen = 0;

        for _ in 0..self.sample_len {
            // Invariant: initial_tokens_len is populated before decoding begins
            // (presumably by get_initial_tokens — TODO confirm).
            let initial_len = self
                .initial_tokens_len
                .expect("initial_tokens_len must be set before decoding");
            // After the first step, feed only the newest token; the decoder's
            // KV cache carries the earlier context.
            let input = if tokens.len() > initial_len {
                &tokens[tokens.len() - 1..]
            } else {
                &tokens
            };
            let input_t = Tensor::from_data(input, shape![1, input.len()], device.clone());

            let logits = decoder
                .schedule([audio_ctx.clone(), input_t])?
                .cast(DType::F32)?
                .resolve()?;
            // Advance the cache by exactly the number of tokens just consumed.
            decoder.cache_mut().update(input.len());

            // Logit post-processing happens on the CPU copy.
            let cpu_logits = logits.to(&Device::CPU)?;
            let mut logits = Self::slice_logits(cpu_logits, sliced_vocab_size);
            let token_t = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], Device::CPU);
            for m in &self.logit_mutators {
                logits = m.apply(logits, &self.tokenizer, Some(&token_t))?;
            }

            let (_, new_tokens, completed) = GreedySampler::sample(tokens, logits)?;

            if let Some(cb) = callback {
                self.handle_callback(&self.tokenizer, &new_tokens, &mut timestamps_seen, cb);
            }

            tokens = new_tokens;
            if completed {
                break;
            }
        }
        Ok(tokens)
    }