in crates/ratchet-models/src/phi2/model.rs [290:344]
fn load_phi2() -> anyhow::Result<()> {
let _ = env_logger::builder().is_test(true).try_init();
let api = Api::new().unwrap();
let model_repo = api.model("FL33TW00D-HF/phi2".to_string());
let model_path = model_repo.get("phi2-f16.gguf").unwrap();
println!("MODEL PATH: {}", model_path.display());
let mut reader = std::io::BufReader::new(std::fs::File::open(model_path)?);
let device = Device::request_device(DeviceRequest::GPU)?;
let content = gguf::gguf::Header::read(&mut reader)?;
let mut model = Phi2::load(content, &mut reader, &device)?;
let tokenizer_repo = api.model("microsoft/phi-2".to_string());
let tokenizer_path = tokenizer_repo.get("tokenizer.json").unwrap();
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
let prompt = "def print_prime(n):";
print!("{}", prompt);
let encoding = tokenizer.encode(prompt, true).unwrap();
let mut tokens = encoding
.get_ids()
.iter()
.map(|&x| x as i32)
.collect::<Vec<_>>();
let mut all_logits = vec![];
let mut all_tokens = tokens.clone();
let mut loop_cnt = 0;
while tokens[tokens.len() - 1] != 50256 && loop_cnt < 13 {
let input = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
let result = model.schedule(input)?.resolve()?;
let logits = result.to(&Device::CPU)?;
all_logits.push(logits.clone());
model.cache_mut().update(tokens.len());
tokens = logits
.to_ndarray_view::<f32>()
.map_axis(Axis(2), |row| row.argmax_skipnan().unwrap())
.iter()
.map(|&x| x as i32)
.collect::<Vec<_>>();
let u32_toks = tokens.iter().map(|&x| x as u32).collect::<Vec<_>>();
print!("{}", tokenizer.decode(&u32_toks, true).unwrap());
all_tokens.extend(tokens.clone());
loop_cnt += 1;
}
let ground_logits = ground_truth()?;
let all_equal = all_logits
.iter()
.zip(ground_logits.iter())
.all(|(our, their)| their.all_close(our, 1e-3, 1e-3).is_ok());
println!("All logits equal: {}", all_equal);
assert!(all_equal);
Ok(())
}