in crates/ratchet-models/src/whisper/decoder.rs [306:384]
fn decoder_matches() -> anyhow::Result<()> {
log_init();
let api = Api::new().unwrap();
let model = api.model("FL33TW00D-HF/whisper-tiny".to_string());
let path = model.get("tiny_f32.gguf").unwrap();
let config_path = model.get("config.json").unwrap();
let config: Config = serde_json::from_slice(&std::fs::read(config_path).unwrap()).unwrap();
println!("MODEL LOADED FROM: {}", path.display());
let dataset = api.dataset("FL33TW00D-HF/ratchet-util".to_string());
let options = DecodingOptionsBuilder::new().build();
let hs_npy = dataset.get("jfk_tiny_encoder_hs.npy").unwrap();
let audio_path = dataset.get("jfk.wav").unwrap();
let tokenizer_repo = api.model("openai/whisper-tiny".to_string());
let tokenizer_path = tokenizer_repo.get("tokenizer.json").unwrap();
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
let mut reader = std::io::BufReader::new(std::fs::File::open(path).unwrap());
let header = gguf::Header::read(&mut reader).unwrap();
let device = Device::request_device(DeviceRequest::GPU).unwrap();
let audio_ctx = Tensor::read_npy::<f32, _>(hs_npy, &device)?
.cast(device.compute_precision())?
.resolve()?;
let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device)?;
let mut tokens = vec![50258, 50259, 50359];
let mut all_tokens = tokens.clone();
let mut all_logits = vec![];
let start = std::time::Instant::now();
while tokens[tokens.len() - 1] != 50257 {
let token_t =
Tensor::from_data(tokens.clone(), shape![1, tokens.len()], device.clone());
let result = decoder
.schedule([audio_ctx.clone(), token_t])?
.resolve_debug()?;
let our_logits = result.to(&Device::CPU)?;
let nd_logits = our_logits.to_ndarray_view::<f32>();
println!("ND LOGITS: {:?}", nd_logits);
all_logits.push(Tensor::from(
nd_logits
.slice(s![.., .., ..tokenizer.get_vocab_size(true)])
.to_owned()
.into_dyn(),
));
let sliced = nd_logits
.slice(s![.., -1.., ..tokenizer.get_vocab_size(true)])
.remove_axis(Axis(1));
decoder.cache_mut().update(tokens.len());
tokens = sliced
.map_axis(Axis(1), |row| row.argmax_skipnan().unwrap())
.iter()
.map(|&x| x as i32)
.collect::<Vec<_>>();
println!("Token: {:?}", tokens);
all_tokens.extend(tokens.clone());
}
println!("Took: {:?}", start.elapsed());
let u32_tokens: Vec<_> = all_tokens.iter().map(|&x| x as u32).collect();
let decoded = tokenizer.decode(&u32_tokens, true).unwrap();
println!("All tokens: {:?}", all_tokens);
println!("Decoded: {}", decoded);
let ground_logits = ground_truth(&audio_path.to_string_lossy(), options)?;
let all_equal = all_logits
.iter()
.zip(ground_logits.iter())
.all(|(our, their)| their.all_close(our, 1e-4, 1e-4).is_ok());
assert!(all_equal);
Ok(())
}