// Source: webgpu-embedding-benchmark/main.js

import { AutoModel, ones } from "@huggingface/transformers";
import Chart from "chart.js/auto";

// Throw an error if WebGPU is not supported
if (!navigator.gpu) {
  const err = "WebGPU is not supported by this browser.";
  alert(err);
  throw new Error(err);
}

// Reference the elements that we will need
const ctx = document.getElementById("chart");
const batchSizes = document.getElementById("batch-sizes");
const xscale = document.getElementById("x-scale");
const yscale = document.getElementById("y-scale");
const sequenceLength = document.getElementById("sequence-length");
const modelID = document.getElementById("model-id");
const status = document.getElementById("status");
const start = document.getElementById("start");
const stop = document.getElementById("stop");
const tests = document.getElementsByClassName("tests");

// Benchmark settings
const NUM_WARMUP_STEPS = 3;
// Loaded models keyed by `${model_id}///${label}`, so repeated runs skip AutoModel.from_pretrained.
const MODEL_CACHE = new Map();

/**
 * Create a fresh "time vs. batch size" line chart on the #chart canvas.
 * The log-scale checkboxes are re-applied here so their state survives the
 * chart being destroyed and recreated at the start of every run.
 * @returns {Chart} The new chart instance.
 */
const initChart = () => {
  // Only override the axis type when log scale is requested; otherwise keep Chart.js defaults.
  const scaleType = (checkbox) => (checkbox.checked ? { type: "logarithmic" } : {});
  const config = {
    type: "line",
    data: {
      labels: [],
      datasets: [],
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        legend: {
          position: "top",
        },
      },
      scales: {
        x: {
          ...scaleType(xscale),
          title: {
            display: true,
            text: "Batch size",
          },
          min: 1,
        },
        y: {
          ...scaleType(yscale),
          title: {
            display: true,
            text: "Time (ms)",
          },
        },
      },
    },
  };
  return new Chart(ctx, config);
};
let chart = initChart();

/**
 * Switch one chart axis between logarithmic and linear scaling.
 * @param {"x" | "y"} axis - Scale id in the chart options.
 * @param {boolean} enabled - Whether to use a logarithmic scale.
 */
const toggleScale = (axis, enabled) => {
  chart.options.scales[axis].type = enabled ? "logarithmic" : "linear";
  chart.update();
};

/**
 * @returns {HTMLInputElement[]} The checked test checkboxes, excluding any that
 *   have been marked `unsupported` on this device.
 */
const getSelectedTests = () => {
  return [...tests].filter((x) => x.checked && !x.hasAttribute("unsupported"));
};

/**
 * Replace the chart's datasets with one empty dataset per selected test,
 * coloured from the checkbox's `data-color` attribute (an "r, g, b" triple).
 */
const updateDatasets = () => {
  chart.data.datasets = getSelectedTests().map((test) => {
    const color = test.getAttribute("data-color");
    return {
      label: test.value,
      data: [],
      borderColor: `rgba(${color}, 1)`,
      backgroundColor: `rgba(${color}, 0.5)`,
    };
  });
  chart.update();
};
updateDatasets();

[...tests].forEach((test) => test.addEventListener("change", updateDatasets));
xscale.addEventListener("change", () => toggleScale("x", xscale.checked));
yscale.addEventListener("change", () => toggleScale("y", yscale.checked));

/**
 * Build dummy model inputs: a single all-ones tensor of shape
 * [batch_size, seqLength], used as both `input_ids` and `attention_mask`.
 */
const generateDummyInputs = (batch_size, seqLength) => {
  const inputs = ones([batch_size, seqLength]);
  return {
    input_ids: inputs,
    attention_mask: inputs,
  };
};

// Query the GPU adapter for fp16 support and for the hardware details included in shared results.
let adapterInfo = {};
let gpuHasFp16 = false;
try {
  // requestAdapter() resolves to null when no suitable adapter is available.
  const adapter = await navigator.gpu.requestAdapter();
  if (adapter) {
    // Check features first so a failing info lookup cannot hide fp16 support.
    gpuHasFp16 = adapter.features.has("shader-f16");
    // `GPUAdapter.info` replaced the removed `requestAdapterInfo()`; fall back to it on older browsers.
    adapterInfo = adapter.info ?? (await adapter.requestAdapterInfo?.()) ?? {};
  }
} catch (err) {
  console.warn("Could not query the WebGPU adapter:", err);
}

if (!gpuHasFp16) {
  const element = document.querySelector(
    '.tests[data-device="webgpu"][data-dtype="fp16"]',
  );
  if (element) {
    element.setAttribute("unsupported", true);
    element.disabled = true;
    element.title = "This device does not support fp16 on WebGPU";
    // The fp16 test may already be in the chart legend; rebuild without it.
    updateDatasets();
  }
}

status.textContent = "Ready";

// Set by the Stop button; checked between models, warm-up steps and batch sizes.
let interrupted = false;

/**
 * Lock or unlock the controls for the duration of a benchmark run.
 * @param {boolean} running - True while a benchmark is in progress.
 * @param {HTMLInputElement[]} validTests - Test checkboxes that may be toggled.
 */
const setRunning = (running, validTests) => {
  start.disabled = running;
  stop.disabled = !running;
  batchSizes.disabled = running;
  sequenceLength.disabled = running;
  modelID.disabled = running;
  validTests.forEach((test) => (test.disabled = running));
};

/**
 * Load (or reuse from MODEL_CACHE) one model per selected test.
 * NOTE: Models must be loaded sequentially (otherwise it will fail due to multiple calls to initWasm())
 * @returns {Promise<Map<string, Function>>} Models keyed by test label, in selection order.
 */
const loadModels = async (model_id, selectedTests) => {
  const testsToRun = new Map();
  for (const { label, dtype, device } of selectedTests) {
    if (interrupted) break;
    const key = `${model_id}///${label}`;
    let model = MODEL_CACHE.get(key);
    if (!model) {
      status.textContent = "Loading model(s)...";
      try {
        model = await AutoModel.from_pretrained(model_id, { device, dtype });
      } catch (err) {
        status.textContent = err.message;
        alert(err.message);
        throw err;
      }
      MODEL_CACHE.set(key, model);
    }
    testsToRun.set(label, model);
  }
  return testsToRun;
};

/**
 * Publish the results: a "Done!" message naming the fastest and slowest tests
 * (compared at the last batch size that was run) plus a link that pre-fills a
 * discussion post with the results table.
 */
const reportResults = (model_id, testNames, seqLength) => {
  const table = generateResultsTable(model_id, testNames, chart.data, seqLength);

  // Calculate slowest and fastest times
  const fastest = { time: Infinity, index: 0 };
  const slowest = { time: 0, index: 0 };
  chart.data.datasets.forEach((dataset, i) => {
    const lastTime = dataset.data.at(-1);
    if (lastTime < fastest.time) {
      fastest.time = lastTime;
      fastest.index = i;
    }
    if (lastTime > slowest.time) {
      slowest.time = lastTime;
      slowest.index = i;
    }
  });
  const roundedSpeedup = (slowest.time / fastest.time).toFixed(2);

  const params = new URLSearchParams({
    title: `⚡ WebGPU Benchmark Results (${roundedSpeedup}x speedup)`,
    description: table.outerHTML,
  });
  const paramsStr = params.toString();
  status.innerHTML = `⚡ Done! ${testNames[fastest.index]} is <strong>${roundedSpeedup}x</strong> faster than ${testNames[slowest.index]}! ⚡<br><a href="https://huggingface.co/spaces/Xenova/webgpu-embedding-benchmark/discussions/new?${paramsStr}" target="_blank">Share results</a>`;
};

/**
 * Run a full benchmark:
 * 1. Reset the chart.
 * 2. Load the models.
 * 3. Warm up.
 * 4. Time one forward pass per test for each batch size, plotting results as they arrive.
 */
const runBenchmark = async (model_id, batch_sizes, seqLength, selectedTests) => {
  // Reset
  chart.destroy();
  chart = initChart();
  updateDatasets();

  const testsToRun = await loadModels(model_id, selectedTests);
  if (interrupted) {
    status.textContent = "Stopped.";
    return;
  }

  status.textContent = "Warming up...";
  // Warm up: This is important for the WebGPU execution provider, which compiles the shaders on first load
  for (let i = 0; i < NUM_WARMUP_STEPS && !interrupted; ++i) {
    const model_inputs = generateDummyInputs(1, seqLength);
    for (const model of testsToRun.values()) {
      await model(model_inputs);
    }
  }
  if (interrupted) {
    status.textContent = "Stopped.";
    return;
  }

  status.textContent = "Running benchmark...";
  for (const batch_size of batch_sizes) {
    if (interrupted) break;
    const model_inputs = generateDummyInputs(batch_size, seqLength);
    const times = [];
    for (const model of testsToRun.values()) {
      const t0 = performance.now();
      await model(model_inputs);
      times.push(performance.now() - t0);
    }
    chart.data.labels.push(batch_size);
    // Datasets and testsToRun share the same selection order.
    times.forEach((time, i) => chart.data.datasets[i].data.push(time));
    chart.update();
  }

  if (chart.data.labels.length === 0) {
    status.textContent = "Stopped.";
    return;
  }
  reportResults(model_id, [...testsToRun.keys()], seqLength);
};

start.addEventListener("click", async () => {
  // Get parameters
  const model_id = modelID.value;
  const batch_sizes = batchSizes.value
    .split(",")
    .map((x) => Number.parseInt(x, 10))
    .filter((x) => x > 0);
  const seqLength = Number.parseInt(sequenceLength.value, 10);
  const selectedTests = getSelectedTests().map((x) => ({
    label: x.value,
    dtype: x.getAttribute("data-dtype"),
    device: x.getAttribute("data-device"),
  }));

  // Validate before locking the UI so bad input never leaves the controls disabled.
  if (selectedTests.length === 0) {
    status.textContent = "Please select at least one test.";
    return;
  }
  if (batch_sizes.length === 0 || !(seqLength > 0)) {
    status.textContent =
      "Please enter positive integer batch sizes and sequence length.";
    return;
  }

  const validTests = [...tests].filter(
    (test) => !test.hasAttribute("unsupported"),
  );
  setRunning(true, validTests);
  interrupted = false;
  try {
    await runBenchmark(model_id, batch_sizes, seqLength, selectedTests);
  } finally {
    // Re-enable the controls even after an early stop or an error.
    setRunning(false, validTests);
  }
});
start.disabled = false;

stop.addEventListener("click", () => {
  status.textContent = "Stopping...";
  interrupted = true;
  stop.disabled = true;
});

/**
 * Build a shareable HTML summary of a benchmark run.
 * @param {string} model_id - Model that was benchmarked.
 * @param {string[]} testNames - Test labels, in the same order as `data.datasets`.
 * @param {{labels: number[], datasets: {data: number[]}[]}} data - Chart data (batch sizes and times in ms).
 * @param {number} sequence_length - Sequence length used for every input.
 * @returns {HTMLDivElement} A container holding the results table and an info list
 *   (model, tests, sequence length, user agent, GPU adapter details).
 */
function generateResultsTable(model_id, testNames, data, sequence_length) {
  const datasets = data.datasets.map((d) => d.data);
  const batch_sizes = data.labels;

  const container = document.createElement("div");

  const table = document.createElement("table");
  const thead = table.createTHead();
  const tbody = table.createTBody();

  // Add header row
  const headerRow = thead.insertRow();
  headerRow.insertCell().textContent = "Batch Size";
  testNames.forEach((model) => {
    headerRow.insertCell().textContent = model;
  });

  // Add data rows
  batch_sizes.forEach((batchSize, rowIndex) => {
    const row = tbody.insertRow();
    row.insertCell().textContent = batchSize;
    datasets.forEach((dataset) => {
      row.insertCell().textContent = dataset[rowIndex].toFixed(2);
    });
  });

  container.appendChild(table);

  const createBulletPoint = (text) => {
    const li = document.createElement("li");
    li.textContent = text;
    return li;
  };

  // Add other information
  const info = document.createElement("ul");
  info.appendChild(createBulletPoint(`Model: ${model_id}`));
  info.appendChild(createBulletPoint(`Tests run: ${testNames.join(", ")}`));
  info.appendChild(createBulletPoint(`Sequence length: ${sequence_length}`));
  info.appendChild(createBulletPoint(`Browser: ${navigator.userAgent}`));
  info.appendChild(
    createBulletPoint(
      `GPU: vendor=${adapterInfo.vendor}, architecture=${adapterInfo.architecture}, device=${adapterInfo.device}, description=${adapterInfo.description}`,
    ),
  );

  container.appendChild(info);

  return container;
}