fn load_phi2()

in crates/ratchet-models/src/phi2/model.rs [290:344]


    fn load_phi2() -> anyhow::Result<()> {
        let _ = env_logger::builder().is_test(true).try_init();
        let api = Api::new().unwrap();
        let model_repo = api.model("FL33TW00D-HF/phi2".to_string());
        let model_path = model_repo.get("phi2-f16.gguf").unwrap();
        println!("MODEL PATH: {}", model_path.display());

        let mut reader = std::io::BufReader::new(std::fs::File::open(model_path)?);
        let device = Device::request_device(DeviceRequest::GPU)?;
        let content = gguf::gguf::Header::read(&mut reader)?;
        let mut model = Phi2::load(content, &mut reader, &device)?;

        let tokenizer_repo = api.model("microsoft/phi-2".to_string());
        let tokenizer_path = tokenizer_repo.get("tokenizer.json").unwrap();
        let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();

        let prompt = "def print_prime(n):";
        print!("{}", prompt);
        let encoding = tokenizer.encode(prompt, true).unwrap();
        let mut tokens = encoding
            .get_ids()
            .iter()
            .map(|&x| x as i32)
            .collect::<Vec<_>>();
        let mut all_logits = vec![];
        let mut all_tokens = tokens.clone();
        let mut loop_cnt = 0;
        while tokens[tokens.len() - 1] != 50256 && loop_cnt < 13 {
            let input = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
            let result = model.schedule(input)?.resolve()?;
            let logits = result.to(&Device::CPU)?;
            all_logits.push(logits.clone());
            model.cache_mut().update(tokens.len());

            tokens = logits
                .to_ndarray_view::<f32>()
                .map_axis(Axis(2), |row| row.argmax_skipnan().unwrap())
                .iter()
                .map(|&x| x as i32)
                .collect::<Vec<_>>();
            let u32_toks = tokens.iter().map(|&x| x as u32).collect::<Vec<_>>();
            print!("{}", tokenizer.decode(&u32_toks, true).unwrap());
            all_tokens.extend(tokens.clone());
            loop_cnt += 1;
        }

        let ground_logits = ground_truth()?;
        let all_equal = all_logits
            .iter()
            .zip(ground_logits.iter())
            .all(|(our, their)| their.all_close(our, 1e-3, 1e-3).is_ok());
        println!("All logits equal: {}", all_equal);
        assert!(all_equal);
        Ok(())
    }