in crates/ratchet-models/src/phi3/model.rs [339:408]
/// End-to-end check of the Phi-3 GGUF model against reference logits.
///
/// Downloads `phi3-mini-4k-f16.gguf` from `FL33TW00D-HF/phi3` and `tokenizer.json`
/// from `microsoft/Phi-3-mini-4k-instruct` via the Hugging Face Hub. It then loads
/// the model on a GPU device and greedily decodes a fixed chat prompt. It prints the
/// decoded text and the generation time. Finally it asserts that the logits produced
/// at every decoding step are close (`atol = rtol = 1e-1`) to those returned by
/// `ground_truth` for the same prompt and length bound.
///
/// # Errors
///
/// Returns an error if any of the following fails: creating the Hub client,
/// downloading files, opening or parsing the GGUF file, acquiring the GPU device,
/// loading the model, tokenizing or decoding text, running a scheduled step, or
/// computing the ground truth.
///
/// # Panics
///
/// Panics if the number of decoding steps differs from the reference. Also panics
/// if any step's logits are not close to the reference; the message names the first
/// diverging step together with the comparison error.
fn load_phi3() -> anyhow::Result<()> {
    // Upper bound on the *total* sequence length (prompt tokens + generated tokens).
    // NOTE(review): the same value is handed to `ground_truth`, so both sides must
    // interpret it identically — presumably as a total length; confirm before changing.
    const MAX_TOKENS: usize = 100;
    // Decoding stops as soon as the model emits this id.
    // NOTE(review): assumed to be `<|end|>` in the Phi-3 vocabulary — confirm with
    // `tokenizer.token_to_id("<|end|>")`.
    const END_TOKEN_ID: i32 = 32007;

    let _ = env_logger::builder().is_test(true).try_init();

    let api = Api::new().map_err(|e| anyhow::anyhow!("failed to create HF Hub client: {}", e))?;
    let model_path = api
        .model("FL33TW00D-HF/phi3".to_string())
        .get("phi3-mini-4k-f16.gguf")
        .map_err(|e| anyhow::anyhow!("failed to fetch phi3 GGUF weights: {}", e))?;
    println!("MODEL PATH: {}", model_path.display());

    let mut reader = std::io::BufReader::new(std::fs::File::open(model_path)?);
    let device = Device::request_device(DeviceRequest::GPU)?;
    let content = gguf::gguf::Header::read(&mut reader)?;
    let mut model = Phi3::load(content, &mut reader, &device)?;

    let tokenizer_path = api
        .model("microsoft/Phi-3-mini-4k-instruct".to_string())
        .get("tokenizer.json")
        .map_err(|e| anyhow::anyhow!("failed to fetch Phi-3 tokenizer: {}", e))?;
    let tokenizer = Tokenizer::from_file(tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {}", e))?;

    let prompt = r#"<|user|>
How to explain Internet for a medieval knight?<|end|>
<|assistant|>"#;
    let encoding = tokenizer
        .encode(prompt, true)
        .map_err(|e| anyhow::anyhow!("failed to encode prompt: {}", e))?;
    let mut tokens = encoding
        .get_ids()
        .iter()
        .map(|&x| x as i32)
        .collect::<Vec<_>>();
    // Id 1 is prepended by hand.
    // TODO: presumably the BOS token, which `encode(.., true)` does not add for this
    // tokenizer — find out why the tokenizer doesn't insert it itself.
    tokens.insert(0, 1);

    let mut all_logits = vec![];
    let mut all_tokens = tokens.clone();
    // Starts at the prompt length, so this tracks the whole sequence, not just the
    // completion; `MAX_TOKENS` therefore bounds prompt + generated tokens.
    let mut seq_len = tokens.len();
    let start = std::time::Instant::now();
    while tokens.last() != Some(&END_TOKEN_ID) && seq_len < MAX_TOKENS {
        let input = Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
        let result = model.schedule(input)?.full()?.resolve()?;
        let logits = result.to(&Device::CPU)?;
        // Tell the model's cache that `tokens.len()` more positions have been processed
        // (presumably a KV cache); the next step feeds only the newly predicted token(s).
        model.cache_mut().update(tokens.len());
        // Greedy decoding: argmax along axis 2 (assumed to be the vocabulary axis).
        tokens = logits
            .to_ndarray_view::<f32>()
            .map_axis(Axis(2), |row| {
                row.argmax_skipnan()
                    .expect("logits row is empty or entirely NaN")
            })
            .iter()
            .map(|&x| x as i32)
            .collect::<Vec<_>>();
        // The argmax borrow has ended, so the logits can be moved instead of cloned.
        all_logits.push(logits);
        all_tokens.extend_from_slice(&tokens);
        seq_len += 1;
    }
    let elapsed = start.elapsed();
    println!("Generation took {:?}", elapsed);

    // Ids came from the tokenizer (u32) or from argmax (usize) and are vocabulary-sized,
    // so the i32 -> u32 round trip is presumed lossless.
    let u32_toks = all_tokens.iter().map(|&x| x as u32).collect::<Vec<_>>();
    let generated = tokenizer
        .decode(&u32_toks, true)
        .map_err(|e| anyhow::anyhow!("failed to decode tokens: {}", e))?;
    println!("We generated: \n{}\n", generated);

    let ground_logits = ground_truth(prompt, MAX_TOKENS)?;
    assert_eq!(all_logits.len(), ground_logits.len());
    // Stop at the first diverging step but keep its comparison error, so a failure
    // reports where and why instead of a bare `false`.
    let mismatch = ground_logits
        .iter()
        .zip(all_logits.iter())
        .enumerate()
        .find_map(|(i, (their, our))| {
            println!("Checking: {}", i);
            our.all_close(their, 1e-1, 1e-1).err().map(|e| (i, e))
        });
    println!("All logits equal: {}", mismatch.is_none());
    if let Some((i, err)) = mismatch {
        panic!("logits at step {} diverge from ground truth: {:?}", i, err);
    }
    Ok(())
}