// Excerpt: fn decoder_matches()
// From: crates/ratchet-models/src/whisper/decoder.rs [306:384]


    /// End-to-end regression test: greedy-decodes the JFK sample with our
    /// Whisper-tiny decoder and checks every step's logits against the
    /// reference implementation's logits (`ground_truth`) within 1e-4.
    ///
    /// Downloads model/config/tokenizer/dataset artifacts from the Hugging
    /// Face hub and requires a GPU device.
    fn decoder_matches() -> anyhow::Result<()> {
        // Whisper special tokens (tiny multilingual):
        // <|startoftranscript|><|en|><|transcribe|> prompt, <|endoftext|> terminator.
        const SOT_SEQUENCE: [i32; 3] = [50258, 50259, 50359];
        const EOT_TOKEN: i32 = 50257;
        // Safety valve so a mis-decoding model fails fast instead of hanging CI.
        const MAX_DECODE_TOKENS: usize = 512;

        log_init();
        let api = Api::new()?;
        let model = api.model("FL33TW00D-HF/whisper-tiny".to_string());
        let path = model.get("tiny_f32.gguf")?;
        let config_path = model.get("config.json")?;
        let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
        println!("MODEL LOADED FROM: {}", path.display());

        let dataset = api.dataset("FL33TW00D-HF/ratchet-util".to_string());
        let options = DecodingOptionsBuilder::new().build();
        let hs_npy = dataset.get("jfk_tiny_encoder_hs.npy")?;
        let audio_path = dataset.get("jfk.wav")?;

        let tokenizer_repo = api.model("openai/whisper-tiny".to_string());
        let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
        // tokenizers::Error is a boxed `dyn Error`, which `?` cannot convert
        // into anyhow::Error directly — wrap it explicitly.
        let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::anyhow!(e))?;

        let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
        let header = gguf::Header::read(&mut reader)?;

        let device = Device::request_device(DeviceRequest::GPU)?;

        // Pre-computed encoder hidden states for the JFK clip, cast to the
        // device's compute precision.
        let audio_ctx = Tensor::read_npy::<f32, _>(hs_npy, &device)?
            .cast(device.compute_precision())?
            .resolve()?;
        let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device)?;

        let mut tokens = SOT_SEQUENCE.to_vec();
        let mut all_tokens = tokens.clone();
        let mut all_logits = vec![];
        // Hoisted out of the loop: the vocab size is loop-invariant.
        let vocab_size = tokenizer.get_vocab_size(true);
        let start = std::time::Instant::now();
        while tokens.last() != Some(&EOT_TOKEN) {
            anyhow::ensure!(
                all_tokens.len() < MAX_DECODE_TOKENS,
                "decoder failed to emit <|endoftext|> within {} tokens",
                MAX_DECODE_TOKENS
            );
            let token_t =
                Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
            let result = decoder
                .schedule([audio_ctx.clone(), token_t])?
                .resolve_debug()?;

            let our_logits = result.to(&Device::CPU)?;
            let nd_logits = our_logits.to_ndarray_view::<f32>();
            // Trim to the tokenizer vocabulary — the logit dimension presumably
            // carries extra/padded columns beyond the real vocab (TODO confirm).
            all_logits.push(Tensor::from(
                nd_logits
                    .slice(s![.., .., ..vocab_size])
                    .to_owned()
                    .into_dyn(),
            ));

            // Greedy sampling over the final position's logits only.
            let sliced = nd_logits
                .slice(s![.., -1.., ..vocab_size])
                .remove_axis(Axis(1));
            // Advance the KV cache by the number of tokens just consumed.
            // NOTE: must run before `tokens` is reassigned below.
            decoder.cache_mut().update(tokens.len());

            tokens = sliced
                .map_axis(Axis(1), |row| {
                    row.argmax_skipnan()
                        .expect("logits row contained only NaNs")
                })
                .iter()
                .map(|&x| x as i32)
                .collect::<Vec<_>>();
            println!("Token: {:?}", tokens);
            all_tokens.extend_from_slice(&tokens);
        }
        println!("Took: {:?}", start.elapsed());

        let u32_tokens: Vec<_> = all_tokens.iter().map(|&x| x as u32).collect();
        let decoded = tokenizer
            .decode(&u32_tokens, true)
            .map_err(|e| anyhow::anyhow!(e))?;
        println!("All tokens: {:?}", all_tokens);
        println!("Decoded: {}", decoded);

        let ground_logits = ground_truth(&audio_path.to_string_lossy(), options)?;
        // Guard against silent `zip` truncation: a step-count mismatch must
        // fail the test rather than vacuously comparing the shorter prefix.
        assert_eq!(
            all_logits.len(),
            ground_logits.len(),
            "decode step count diverged from ground truth"
        );
        let all_equal = all_logits
            .iter()
            .zip(ground_logits.iter())
            .all(|(ours, theirs)| theirs.all_close(ours, 1e-4, 1e-4).is_ok());

        assert!(all_equal);

        Ok(())
    }