fn load_phi3()

in crates/ratchet-models/src/phi3/model.rs [339:408]


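    // Integration test: pull the Phi-3-mini GGUF weights from the HF Hub, greedily
    // decode a prompt on the GPU, and compare per-step logits against ground truth.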
    fn load_phi3() -> anyhow::Result<()> {
        let _ = env_logger::builder().is_test(true).try_init();
        let api = Api::new().unwrap();
        let model_repo = api.model("FL33TW00D-HF/phi3".to_string());
        let model_path = model_repo.get("phi3-mini-4k-f16.gguf").unwrap();
        println!("MODEL PATH: {}", model_path.display());

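        // Parse the GGUF header, then load the weights onto the requested GPU device.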
        let mut reader = std::io::BufReader::new(std::fs::File::open(model_path)?);
        let device = Device::request_device(DeviceRequest::GPU)?;
        let content = gguf::gguf::Header::read(&mut reader)?;
        let mut model = Phi3::load(content, &mut reader, &device)?;

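        // The tokenizer comes from the upstream microsoft/Phi-3-mini-4k-instruct repo.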
        let tokenizer_repo = api.model("microsoft/Phi-3-mini-4k-instruct".to_string());
        let tokenizer_path = tokenizer_repo.get("tokenizer.json").unwrap();
        let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();

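        // Build a prompt in the Phi-3 chat format and cap generation at MAX_TOKENS.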
        const MAX_TOKENS: usize = 100;
        let prompt = r#"<|user|>
How to explain Internet for a medieval knight?<|end|>
<|assistant|>"#;
        let encoding = tokenizer.encode(prompt, true).unwrap();
        let mut tokens = encoding
            .get_ids()
            .iter()
            .map(|&x| x as i32)
            .collect::<Vec<_>>();
        tokens.insert(0, 1); // Prepend the BOS token (<s>, id 1); this tokenizer does not appear to add it during encode.
        let mut all_logits = vec![];
        let mut all_tokens = tokens.clone();
        let mut generated_cnt = tokens.len(); // counts prompt + generated tokens against MAX_TOKENS
        let start = std::time::Instant::now();

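        // Greedy decoding loop: stop once the model emits <|end|> (id 32007) or the token budget is spent.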
        while *tokens.last().unwrap() != 32007 && generated_cnt < MAX_TOKENS {
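            // Run the pending tokens through the model and copy the logits back to the CPU.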
            let input = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
            let result = model.schedule(input)?.full()?.resolve()?;
            let logits = result.to(&Device::CPU)?;
            all_logits.push(logits.clone());
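            // Advance the KV cache past the tokens just consumed.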
            model.cache_mut().update(tokens.len());

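            // Greedy sampling: argmax over the vocab axis selects the next token id.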
            tokens = logits
                .to_ndarray_view::<f32>()
                .map_axis(Axis(2), |row| row.argmax_skipnan().unwrap())
                .iter()
                .map(|&x| x as i32)
                .collect::<Vec<_>>();
            all_tokens.extend(tokens.clone());
            generated_cnt += 1;
        }
        let elapsed = start.elapsed();
        println!("Elapsed: {:?}", elapsed);
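        // Decode the full sequence (prompt + generated) back into text.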
        let u32_toks = all_tokens.iter().map(|&x| x as u32).collect::<Vec<_>>();

        let generated = tokenizer.decode(&u32_toks, true).unwrap();
        println!("We generated: \n{}\n", generated);

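        // Each step's logits must match the reference within loose tolerances (f16 weights).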
        let ground_logits = ground_truth(prompt, MAX_TOKENS)?;
        assert_eq!(all_logits.len(), ground_logits.len());
        let all_equal =
            ground_logits
                .iter()
                .zip(all_logits.iter())
                .enumerate()
                .all(|(i, (their, our))| {
                    println!("Checking step {}", i);
                    our.all_close(their, 1e-1, 1e-1).is_ok()
                });

        println!("All logits equal: {}", all_equal);
        assert!(all_equal);
        Ok(())
    }