janus-webgpu/src/worker.js
import {
AutoProcessor,
MultiModalityCausalLM,
BaseStreamer,
TextStreamer,
InterruptableStoppingCriteria,
} from "@huggingface/transformers";
// Define constants
const IMAGE_GENERATION_COMMAND_PREFIX = "/imagine ";
const MAX_NEW_TEXT_TOKENS = 1024;
let fp16_supported = false;
/**
 * Helper function to perform WebGPU feature detection
 */
async function check() {
try {
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error("WebGPU is not supported (no adapter found)");
}
fp16_supported = adapter.features.has("shader-f16");
self.postMessage({
status: "success",
data: fp16_supported,
});
} catch (e) {
self.postMessage({
status: "error",
data: e.toString(),
});
}
}
/**
* This class uses the Singleton pattern to enable lazy-loading of the pipeline
*/
class ImageGenerationPipeline {
static model_id = "onnx-community/Janus-1.3B-ONNX";
static async getInstance(progress_callback = null) {
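    // Note: from_pretrained() returns promises; caching them with ??= means
    // repeated or concurrent calls share a single download/initialization.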
this.processor ??= AutoProcessor.from_pretrained(this.model_id, {
progress_callback,
});
this.model ??= MultiModalityCausalLM.from_pretrained(this.model_id, {
dtype: fp16_supported
? {
prepare_inputs_embeds: "q4",
language_model: "q4f16",
lm_head: "fp16",
gen_head: "fp16",
gen_img_embeds: "fp16",
image_decode: "fp32",
}
: {
prepare_inputs_embeds: "fp32",
language_model: "q4",
lm_head: "fp32",
gen_head: "fp32",
gen_img_embeds: "fp32",
image_decode: "fp32",
},
device: {
prepare_inputs_embeds: "wasm", // TODO use "webgpu" when bug is fixed
language_model: "webgpu",
lm_head: "webgpu",
gen_head: "webgpu",
gen_img_embeds: "webgpu",
image_decode: "webgpu",
},
progress_callback,
});
return Promise.all([this.processor, this.model]);
}
}
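/**
 * Streamer that reports image-generation progress by counting the tokens
 * produced after the initial prompt, along with the elapsed time.
 */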
class ProgressStreamer extends BaseStreamer {
constructor(total, on_progress) {
super();
this.total = total;
this.on_progress = on_progress;
this.count = null;
this.start_time = null;
}
put(value) {
if (this.count === null) {
// Ignore the first batch of tokens (prompt)
this.count = 0;
this.start_time = performance.now();
return;
}
const progress = ++this.count / this.total;
this.on_progress({
count: this.count,
total: this.total,
progress,
time: performance.now() - this.start_time,
});
}
end() {
    /* do nothing */
}
}
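// Shared stopping criteria so an "interrupt" message from the main thread can
// cancel an in-progress text generation.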
const stopping_criteria = new InterruptableStoppingCriteria();
async function generate(messages) {
// For this demo, we only respond to the last message
const message = messages.at(-1);
// Tell the main thread we are starting
self.postMessage({ status: "start" });
// Load the pipeline
const [processor, model] = await ImageGenerationPipeline.getInstance();
// Determine if the user wants to generate an image or text
if (message.content.startsWith(IMAGE_GENERATION_COMMAND_PREFIX)) {
const text = message.content.replace(IMAGE_GENERATION_COMMAND_PREFIX, "");
const conversation = [
{
role: "User", // uses title case
content: text,
},
];
const inputs = await processor(conversation, {
chat_template: "text_to_image",
});
const callback_function = (output) => {
self.postMessage({
status: "image-update",
...output,
});
};
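    // Janus generates a fixed number of image tokens per image, so we use it
    // both to bound generation and to report progress.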
const num_image_tokens = processor.num_image_tokens;
const streamer = new ProgressStreamer(num_image_tokens, callback_function);
const outputs = await model.generate_images({
...inputs,
min_new_tokens: num_image_tokens,
max_new_tokens: num_image_tokens,
do_sample: true,
streamer,
});
const blob = await outputs[0].toBlob();
// Send the output back to the main thread
self.postMessage({
status: "image-update",
blob,
});
} else {
const inputs = await processor(
message.image
? [
{
role: "User",
content: "<image_placeholder>\n" + message.content,
images: [message.image],
},
]
: [
{
role: "System",
content:
"You are a helpful assistant. Answer the user's questions in a concise manner.",
},
{
role: "User",
content: message.content,
},
],
);
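    // Track decoding speed (tokens per second) to report alongside each text update.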
let startTime;
let numTokens = 0;
let tps;
const token_callback_function = () => {
startTime ??= performance.now();
if (numTokens++ > 0) {
tps = (numTokens / (performance.now() - startTime)) * 1000;
}
};
const callback_function = (output) => {
self.postMessage({
status: "text-update",
output,
tps,
numTokens,
});
};
const streamer = new TextStreamer(processor.tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
callback_function,
token_callback_function,
});
// Generate response
const outputs = await model.generate({
...inputs,
max_new_tokens: MAX_NEW_TEXT_TOKENS,
do_sample: false,
streamer,
stopping_criteria,
});
}
// Tell the main thread we are done
self.postMessage({
status: "complete",
});
}
async function load() {
self.postMessage({
status: "loading",
data: "Loading model...",
});
// Load the pipeline and save it for future use.
  await ImageGenerationPipeline.getInstance((x) => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
self.postMessage(x);
});
self.postMessage({ status: "ready" });
}
// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
const { type, data } = e.data;
switch (type) {
case "check":
check();
break;
case "load":
load();
break;
case "generate":
stopping_criteria.reset();
generate(data);
break;
case "interrupt":
stopping_criteria.interrupt();
break;
case "reset":
stopping_criteria.reset();
break;
}
});
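// Illustrative sketch of how a main thread could drive this worker; the worker
// URL, variable names, and prompt below are assumptions, not part of this file:
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), { type: "module" });
//   worker.addEventListener("message", (e) => {
//     // e.data.status: "success"/"error" (check), "loading"/"ready" (load),
//     // "start", "image-update", "text-update", "complete" (generate)
//   });
//   worker.postMessage({ type: "check" }); // WebGPU fp16 feature detection
//   worker.postMessage({ type: "load" }); // download and initialize the model
//   worker.postMessage({
//     type: "generate",
//     data: [{ content: "/imagine a watercolor fox" }], // "/imagine " prefix => image generation
//   });
//   worker.postMessage({ type: "interrupt" }); // stop an in-progress text generation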