florence2-webgpu/src/worker.js
import {
  Florence2ForConditionalGeneration,
  AutoProcessor,
  AutoTokenizer,
  RawImage,
  full,
} from "@huggingface/transformers";
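
// Feature-detect half-precision shader support on the WebGPU adapter, so we
// can load fp16 weights where the GPU supports them and fall back to fp32
// otherwise. Returns false if WebGPU is unavailable or the query throws.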
async function hasFp16() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    return adapter.features.has("shader-f16");
  } catch (e) {
    return false;
  }
}

/**
 * This class uses the Singleton pattern to ensure that only one instance of
 * the model is loaded.
 */
class Florence2Singleton {
  static model_id = "onnx-community/Florence-2-base-ft";
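
  // from_pretrained() returns a promise; `??=` caches the promise itself, so
  // concurrent callers share a single in-flight load instead of each
  // starting their own download.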
  static async getInstance(progress_callback = null) {
    this.processor ??= AutoProcessor.from_pretrained(this.model_id);
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id);

    this.supports_fp16 ??= await hasFp16();
    this.model ??= Florence2ForConditionalGeneration.from_pretrained(
      this.model_id,
      {
        dtype: {
          embed_tokens: this.supports_fp16 ? "fp16" : "fp32",
          vision_encoder: this.supports_fp16 ? "fp16" : "fp32",
          encoder_model: "q4", // or 'fp16' or 'fp32'
          decoder_model_merged: "q4", // or 'fp16' or 'fp32'
        },
        device: "webgpu",
        progress_callback,
      },
    );

    return Promise.all([this.model, this.tokenizer, this.processor]);
  }
}
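
// Download the model, tokenizer, and processor, then run a single dummy
// generation so the WebGPU shaders are compiled before the first real request.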
async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

  // Load the pipeline and save it for future use.
  const [model, tokenizer, processor] = await Florence2Singleton.getInstance(
    (x) => {
      // We also add a progress callback to the pipeline so that we can
      // track model loading.
      self.postMessage(x);
    },
  );

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });

  // Dummy text and vision inputs
  const text_inputs = tokenizer("a");
  const pixel_values = full([1, 3, 768, 768], 0.0);

  // Run model with dummy input to compile shaders
  await model.generate({
    ...text_inputs,
    pixel_values,
    max_new_tokens: 1,
  });

  self.postMessage({ status: "ready" });
}
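
// Tasks that take additional user text appended after the task token (e.g.
// "<CAPTION_TO_PHRASE_GROUNDING>" followed by the phrase to locate). The
// preprocessed image is cached between runs and cleared by the "reset" message.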
const TASKS_WITH_INPUTS = ["<CAPTION_TO_PHRASE_GROUNDING>"];
let vision_inputs;
let image_size;

async function run({ text, url, task }) {
  const [model, tokenizer, processor] = await Florence2Singleton.getInstance();

  // Read and preprocess image
  const start = performance.now();
  if (!vision_inputs) {
    // Cache vision inputs when possible
    const image = await RawImage.fromURL(url);
    image_size = image.size;
    vision_inputs = await processor(image);
  }

  let user_input = task;
  if (TASKS_WITH_INPUTS.includes(task) && text) {
    user_input += text;
  }
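
  // construct_prompts() expands the task token into the full text prompt the
  // model was trained with, which is then tokenized.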
  const prompts = processor.construct_prompts(user_input);
  const text_inputs = tokenizer(prompts);

  // Generate text
  const generated_ids = await model.generate({
    ...text_inputs,
    ...vision_inputs,
    max_new_tokens: 128,
    num_beams: 1,
    do_sample: false,
  });

  // Decode generated text, keeping special tokens: post-processing parses
  // them to recover task-specific output (e.g. grounding coordinates).
  const generated_text = tokenizer.batch_decode(generated_ids, {
    skip_special_tokens: false,
  })[0];

  // Post-process the generated text
  const result = processor.post_process_generation(
    generated_text,
    task,
    image_size,
  );

  const end = performance.now();
  self.postMessage({ status: "complete", result, time: end - start });
}

// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;
  switch (type) {
    case "load":
      load();
      break;

    case "run":
      run(data);
      break;

    case "reset":
      vision_inputs = image_size = null;
      break;
  }
});
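
// A minimal sketch of how the main thread might drive this worker. It is not
// part of this file, and the image URL is hypothetical, but the message
// shapes ({ type, data }) match the handlers above:
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), {
//     type: "module",
//   });
//
//   worker.addEventListener("message", (e) => {
//     const { status, result, time } = e.data;
//     if (status === "ready") {
//       worker.postMessage({
//         type: "run",
//         data: { url: "https://example.com/cat.jpg", task: "<CAPTION>" },
//       });
//     } else if (status === "complete") {
//       console.log(result, `${time.toFixed(0)} ms`);
//     }
//   });
//
//   worker.postMessage({ type: "load" });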