crates/ratchet-web/src/model.rs (357 lines of code) (raw):
use crate::db::*;
use futures::stream::TryStreamExt;
use futures::StreamExt;
use ratchet_hub::{Api, ApiBuilder, RepoType};
use ratchet_loader::gguf::gguf::{self, Header, TensorInfo};
use ratchet_models::moondream::{self, Moondream};
use ratchet_models::phi2;
use ratchet_models::phi2::Phi2;
use ratchet_models::phi3::{self, Phi3};
use ratchet_models::registry::{AvailableModels, PhiVariants, Quantization};
use ratchet_models::whisper::{transcribe::transcribe, transcript::StreamedSegment, Whisper};
use ratchet_models::TensorMap;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
/// A model loaded in the browser and ready for inference.
///
/// Each variant wraps one concrete architecture; [`WebModel::run`] selects the
/// matching input type and generation routine for the variant.
#[derive(Debug)]
pub enum WebModel {
    /// Speech-to-text; driven by [`WhisperInputs`].
    Whisper(Whisper),
    /// Text generation; driven by [`PhiInputs`].
    Phi2(Phi2),
    /// Text generation; driven by [`PhiInputs`].
    Phi3(Phi3),
    /// Visual question answering; driven by [`MoondreamInputs`].
    Moondream(Moondream),
}
impl WebModel {
pub async fn run(&mut self, input: JsValue) -> Result<JsValue, JsValue> {
match self {
WebModel::Whisper(model) => {
let input: WhisperInputs = serde_wasm_bindgen::from_value(input)?;
let options = serde_wasm_bindgen::from_value(input.decode_options)?;
let callback = if !input.callback.is_null() {
let rs_callback = |decoded: StreamedSegment| {
let js_decoded = serde_wasm_bindgen::to_value(&decoded).unwrap();
let _ = input.callback.call1(&JsValue::NULL, &js_decoded);
};
Some(rs_callback)
} else {
None
};
let result = transcribe(model, input.audio, options, callback)
.await
.unwrap();
serde_wasm_bindgen::to_value(&result).map_err(|e| e.into())
}
WebModel::Phi2(model) => {
let input: PhiInputs = serde_wasm_bindgen::from_value(input)?;
let rs_callback = |output: String| {
let _ = input.callback.call1(&JsValue::NULL, &output.into());
};
let prompt = input.prompt;
let model_repo = ApiBuilder::from_hf("microsoft/phi-2", RepoType::Model).build();
let model_bytes = model_repo.get("tokenizer.json").await?;
let tokenizer = Tokenizer::from_bytes(model_bytes.to_vec()).unwrap();
phi2::generate(model, tokenizer, prompt, rs_callback)
.await
.unwrap();
Ok(JsValue::NULL)
}
WebModel::Phi3(model) => {
let input: PhiInputs = serde_wasm_bindgen::from_value(input)?;
let rs_callback = |output: String| {
let _ = input.callback.call1(&JsValue::NULL, &output.into());
};
let prompt = input.prompt;
let model_repo =
ApiBuilder::from_hf("microsoft/Phi-3-mini-4k-instruct", RepoType::Model)
.build();
let model_bytes = model_repo.get("tokenizer.json").await?;
let tokenizer = Tokenizer::from_bytes(model_bytes.to_vec()).unwrap();
phi3::generate(model, tokenizer, prompt, rs_callback)
.await
.unwrap();
Ok(JsValue::NULL)
}
WebModel::Moondream(model) => {
let input: MoondreamInputs = serde_wasm_bindgen::from_value(input)?;
let rs_callback = |output: String| {
let _ = input.callback.call1(&JsValue::NULL, &output.into());
};
let model_repo =
ApiBuilder::from_hf("tgestson/ratchet-moondream2", RepoType::Model).build();
let model_bytes = model_repo.get("tokenizer.json").await?;
let tokenizer = Tokenizer::from_bytes(model_bytes.to_vec()).unwrap();
moondream::generate(
model,
input.image_bytes,
input.question,
tokenizer,
rs_callback,
)
.await
.unwrap();
Ok(JsValue::NULL)
}
}
}
pub async fn from_stored(
model_record: ModelRecord,
tensor_map: TensorMap,
) -> Result<WebModel, anyhow::Error> {
let header = serde_wasm_bindgen::from_value::<Header>(model_record.header).unwrap();
match model_record.model {
AvailableModels::Whisper(variant) => {
let model = Whisper::from_web(header, tensor_map, variant).await?;
Ok(WebModel::Whisper(model))
}
AvailableModels::Phi(variant) => match variant {
PhiVariants::Phi2 => {
let model = Phi2::from_web(header, tensor_map).await?;
Ok(WebModel::Phi2(model))
}
PhiVariants::Phi3 => {
let model = Phi3::from_web(header, tensor_map).await?;
Ok(WebModel::Phi3(model))
}
},
AvailableModels::Moondream => {
let model = Moondream::from_web(header, tensor_map).await?;
Ok(WebModel::Moondream(model))
}
_ => Err(anyhow::anyhow!("Unknown model type")),
}
}
}
/// Untyped JS input accepted by the Whisper variant of [`WebModel::run`].
#[derive(serde::Deserialize, serde::Serialize)]
pub struct WhisperInputs {
    /// Audio samples handed to `transcribe`.
    /// NOTE(review): presumably mono f32 PCM; confirm the required sample rate.
    pub audio: Vec<f32>,
    /// Raw JS object, deserialized into Whisper decoding options inside `run`.
    #[serde(with = "serde_wasm_bindgen::preserve")]
    pub decode_options: JsValue,
    /// Called with each streamed segment serialized to a JS value;
    /// `run` skips streaming when this is `null`.
    #[serde(with = "serde_wasm_bindgen::preserve")]
    pub callback: js_sys::Function,
}
/// Untyped JS input accepted by the Phi-2 and Phi-3 variants of [`WebModel::run`].
#[derive(serde::Deserialize, serde::Serialize)]
pub struct PhiInputs {
    /// Prompt text passed to the model's `generate` routine.
    pub prompt: String,
    /// Called with each `String` that `generate` emits through its callback.
    #[serde(with = "serde_wasm_bindgen::preserve")]
    pub callback: js_sys::Function,
}
/// Untyped JS input accepted by the Moondream variant of [`WebModel::run`].
#[derive(serde::Deserialize, serde::Serialize)]
pub struct MoondreamInputs {
    /// Question asked about the image.
    pub question: String,
    /// Image bytes passed to `moondream::generate`.
    /// NOTE(review): presumably an encoded image file (e.g. PNG/JPEG); confirm.
    pub image_bytes: Vec<u8>,
    /// Called with each `String` that `generate` emits through its callback.
    #[serde(with = "serde_wasm_bindgen::preserve")]
    pub callback: js_sys::Function,
}
/// JS-facing handle to a loaded model.
///
/// Create one with `Model.load` or `Model.load_custom`, then call `run`.
#[wasm_bindgen]
#[derive(Debug)]
pub struct Model {
    /// The concrete architecture that performs inference.
    inner: WebModel,
}
#[wasm_bindgen]
impl Model {
    /// The main JS entrypoint into the library.
    ///
    /// Loads `model` at the requested `quantization` from its Hugging Face repository.
    /// Weights are cached in `RatchetDB` on first use. `progress` is called with the
    /// cumulative percentage of tensor bytes downloaded.
    #[wasm_bindgen]
    pub async fn load(
        model: AvailableModels,
        quantization: Quantization,
        progress: &js_sys::Function,
    ) -> Result<Model, JsValue> {
        let model_key = ModelKey::from_available(&model, quantization);
        let model_repo = ApiBuilder::from_hf(&model_key.repo_id(), RepoType::Model).build();
        let inner = Self::load_inner(model, &model_repo, model_key, progress).await?;
        Ok(Model { inner })
    }

    /// Like [`Model::load`], but fetches files from the base URL `endpoint`
    /// (via `ApiBuilder::from_custom`) instead of the Hugging Face hub.
    ///
    /// NOTE(review): the cache key is the same `ModelKey` used by `load`, so a model
    /// cached by either entry point is reused by the other.
    #[wasm_bindgen]
    pub async fn load_custom(
        endpoint: String,
        model: AvailableModels,
        quantization: Quantization,
        progress: &js_sys::Function,
    ) -> Result<Model, JsValue> {
        let model_key = ModelKey::from_available(&model, quantization);
        let model_repo = ApiBuilder::from_custom(endpoint).build();
        let inner = Self::load_inner(model, &model_repo, model_key, progress).await?;
        Ok(Model { inner })
    }

    /// Returns the model for `model_key`, downloading it into `RatchetDB` on a cache miss.
    ///
    /// On a miss, the GGUF header and every tensor are fetched from `model_repo`
    /// (reporting through `progress`), and then the model record is stored. The
    /// record's tensors are then read from the database to build the [`WebModel`].
    ///
    /// # Errors
    /// Any database, network, deserialization, or model-construction failure is
    /// returned as a JS error rather than panicking.
    async fn load_inner(
        model: AvailableModels,
        model_repo: &Api,
        model_key: ModelKey,
        progress: &js_sys::Function,
    ) -> Result<WebModel, JsValue> {
        let db = RatchetDB::open()
            .await
            .map_err(|e| JsValue::from(JsError::from(e)))?;
        log::warn!("Loading model: {:?}", model_key);
        let cached = db
            .get_model(&model_key)
            .await
            .map_err(|e| JsValue::from(JsError::from(e)))?;
        let model_record = match cached {
            Some(record) => record,
            None => {
                let header: gguf::Header = serde_wasm_bindgen::from_value(
                    model_repo.fetch_gguf_header(&model_key.model_id()).await?,
                )?;
                // Tensors are written before the model record, so a record is only
                // stored once every tensor has been fetched and stored successfully.
                Self::fetch_tensors(&db, model_repo, &header, model_key.clone(), progress)
                    .await?;
                let record = ModelRecord::new(model_key.clone(), model, header);
                db.put_model(&model_key, record)
                    .await
                    .map_err(|e| JsValue::from(JsError::from(e)))?;
                // `record` was moved into the DB; read back the stored form.
                db.get_model(&model_key)
                    .await
                    .map_err(|e| JsValue::from(JsError::from(e)))?
                    .ok_or_else(|| {
                        JsValue::from(JsError::new("Model record missing after being stored"))
                    })?
            }
        };
        let tensors = db
            .get_tensors(&model_key)
            .await
            .map_err(|e| JsValue::from(JsError::from(e)))?;
        WebModel::from_stored(model_record, tensors)
            .await
            .map_err(|e| JsValue::from(JsError::new(&format!("{:#}", e))))
    }

    /// User-facing method to run the model.
    ///
    /// `input` is untyped because its expected shape depends on the loaded model:
    /// it must deserialize into [`WhisperInputs`], [`PhiInputs`] or [`MoondreamInputs`].
    pub async fn run(&mut self, input: JsValue) -> Result<JsValue, JsValue> {
        self.inner.run(input).await
    }

    /// Downloads every tensor described by `header` and stores each one in `db`
    /// under `model_key`.
    ///
    /// Tensors are requested largest-first, with at most `MAX_CONCURRENT_FETCHES`
    /// range requests in flight. After each tensor is stored, `progress` is called
    /// with the cumulative percentage of tensor bytes handled so far.
    ///
    /// # Errors
    /// Returns the first fetch or storage error and abandons the remaining downloads.
    /// Previously, storage errors were silently dropped and fetch errors panicked.
    async fn fetch_tensors(
        db: &RatchetDB,
        model_repo: &Api,
        header: &Header,
        model_key: ModelKey,
        progress: &js_sys::Function,
    ) -> Result<(), JsValue> {
        const MAX_CONCURRENT_FETCHES: usize = 6;

        let model_id = model_key.model_id();
        let data_offset = header.tensor_data_offset;
        let content_len = header
            .tensor_infos
            .values()
            .fold(0, |acc, ti| acc + ti.size_in_bytes());
        let mut tensor_infos: Vec<(String, TensorInfo)> =
            header.tensor_infos.clone().into_iter().collect();
        // Largest first, so the longest downloads start as early as possible.
        tensor_infos.sort_by(|(_, a), (_, b)| b.size_in_bytes().cmp(&a.size_in_bytes()));
        let mut total_progress = 0.0;
        futures::stream::iter(tensor_infos)
            .map(|(name, info): (String, TensorInfo)| {
                let model_id = model_id.clone();
                let model_key = model_key.clone();
                async move {
                    let range = info.byte_range(data_offset);
                    let bytes = model_repo
                        .fetch_range(&model_id, range.start, range.end)
                        .await?;
                    let length = bytes.length();
                    let record = TensorRecord::new(name, model_key, bytes);
                    db.put_tensor(record)
                        .await
                        .map_err(|e| JsValue::from(JsError::from(e)))?;
                    Ok::<_, JsValue>(length)
                }
            })
            .buffer_unordered(MAX_CONCURRENT_FETCHES)
            .try_for_each(|num_bytes| {
                total_progress += (num_bytes as f64) / (content_len as f64) * 100.0;
                // Errors thrown by the JS progress callback are deliberately ignored.
                let _ = progress.call1(&JsValue::NULL, &total_progress.into());
                futures::future::ready(Ok::<(), JsValue>(()))
            })
            .await
    }
}
#[cfg(all(test, target_arch = "wasm32"))]
mod tests {
    use super::*;
    use ratchet_hub::{ApiBuilder, RepoType};
    use ratchet_models::registry::WhisperVariants;
    use ratchet_models::whisper::options::DecodingOptionsBuilder;
    use wasm_bindgen_test::*;

    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// Installs the console panic hook and a `fern` logger that forwards to `console_log`.
    ///
    /// Tests may call this repeatedly: if a logger is already installed, `apply()`
    /// fails and the error is only printed.
    fn log_init() {
        console_error_panic_hook::set_once();
        let logger = fern::Dispatch::new()
            .format(|out, message, record| {
                out.finish(format_args!(
                    "{}[{}][{}] {}",
                    chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"),
                    record.target(),
                    record.level(),
                    message
                ))
            })
            .level_for("tokenizers", log::LevelFilter::Off)
            .level(log::LevelFilter::Warn)
            .chain(fern::Output::call(console_log::log))
            .apply();
        match logger {
            Ok(_) => log::info!("Logging initialized."),
            Err(error) => eprintln!("Error initializing logging: {:?}", error),
        }
    }

    /// Decodes 16-bit integer WAV samples into `f32` values scaled by 1/32768.
    fn load_sample(bytes: &[u8]) -> Vec<f32> {
        let mut reader = hound::WavReader::new(std::io::Cursor::new(bytes)).unwrap();
        reader
            .samples::<i16>()
            .map(|x| x.unwrap() as f32 / 32768.0)
            .collect::<Vec<_>>()
    }

    /// Builds a JS-callable progress callback that logs each reported value.
    fn progress_logger() -> Closure<dyn Fn(f64)> {
        Closure::new(|p| {
            log::info!("Provided closure got progress: {}", p);
        })
    }

    /// Transcribes the shared `mm0.wav` sample with `model`, logging each streamed
    /// segment and then the final result.
    async fn transcribe_sample(model: &mut Model) -> Result<(), JsValue> {
        let data_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build();
        let audio_bytes = data_repo.get("mm0.wav").await?;
        let sample = load_sample(&audio_bytes.to_vec());
        let decode_options = DecodingOptionsBuilder::default().build();
        // Kept alive until after `model.run` returns, since the run invokes it.
        let segment_cb: Closure<dyn Fn(JsValue)> = Closure::new(|s| {
            log::info!("GENERATED SEGMENT: {:?}", s);
        });
        let js_segment_cb: &js_sys::Function = segment_cb.as_ref().unchecked_ref();
        let input = WhisperInputs {
            audio: sample,
            decode_options,
            callback: js_segment_cb.clone(),
        };
        let input = serde_wasm_bindgen::to_value(&input).unwrap();
        let result = model.run(input).await.unwrap();
        log::warn!("Result: {:?}", result);
        Ok(())
    }

    #[wasm_bindgen_test]
    async fn whisper_browser() -> Result<(), JsValue> {
        log_init();
        let download_cb = progress_logger();
        let js_cb: &js_sys::Function = download_cb.as_ref().unchecked_ref();
        let mut model = Model::load(
            AvailableModels::Whisper(WhisperVariants::Base),
            Quantization::F16,
            js_cb,
        )
        .await
        .unwrap();
        transcribe_sample(&mut model).await
    }

    #[wasm_bindgen_test]
    async fn whisper_browser_custom() -> Result<(), JsValue> {
        log_init();
        let download_cb = progress_logger();
        let js_cb: &js_sys::Function = download_cb.as_ref().unchecked_ref();
        let mut model = Model::load_custom(
            "https://huggingface.co/FL33TW00D-HF/whisper-base/resolve/main".to_string(),
            AvailableModels::Whisper(WhisperVariants::Base),
            Quantization::F16,
            js_cb,
        )
        .await
        .unwrap();
        transcribe_sample(&mut model).await
    }
}