// packages/tasks-gen/scripts/generate-snippets-fixtures.ts
/*
* Generates Inference API snippets using @huggingface/tasks snippets.
*
* If used in test mode ("pnpm test"), it compares the generated snippets with the expected ones.
* If used in generation mode ("pnpm generate-snippets-fixtures"), it generates the expected snippets.
*
* Expected snippets are saved under ./snippets-fixtures and are meant to be versioned on GitHub.
 * Each snippet is saved in a separate file placed under "./{test-name}/{language}/{client}/{index}.{provider}.{extension}":
 * - test-name: the name of the test (e.g. "text-to-image", "conversational-llm-stream", etc.)
 * - language: the language of the snippet ("js", "python", or "sh")
 * - client: the client name (e.g. "requests", "huggingface_hub", "openai", etc.). Defaults to "default" if not specified.
 * - index: the position of the snippet within the client's list of snippets (0 for the first snippet of each client)
 * - provider: the inference provider the snippet targets (e.g. "hf-inference", "fal-ai", "together", etc.)
 * - extension: the file extension matching the language ("js", "py", or "sh")
 *
 * Example:
 * ./packages/tasks-gen/snippets-fixtures/text-to-image/python/huggingface_hub/0.hf-inference.py
*/
import { existsSync as pathExists } from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path/posix";
import type { InferenceProviderOrPolicy } from "@huggingface/inference";
import { snippets } from "@huggingface/inference";
import type { InferenceSnippet, ModelDataMinimal, SnippetInferenceProvider, WidgetType } from "@huggingface/tasks";
import { inferenceSnippetLanguages } from "@huggingface/tasks";
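
// Languages for which fixtures are generated, mapped to the file extension used on disk.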
const LANGUAGES = ["js", "python", "sh"] as const;
type Language = (typeof LANGUAGES)[number];
const EXTENSIONS: Record<Language, string> = { sh: "sh", js: "js", python: "py" };
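
// One entry per fixture folder: a human-readable test name, the widget task, a minimal
// model payload, the providers to generate snippets for, and optional snippet options
// (streaming, custom endpoint, billTo, access token, direct request, LoRA adapter, ...).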
const TEST_CASES: {
testName: string;
task: WidgetType;
model: ModelDataMinimal;
providers: InferenceProviderOrPolicy[];
lora?: boolean;
opts?: snippets.InferenceSnippetOptions;
}[] = [
{
testName: "automatic-speech-recognition",
task: "automatic-speech-recognition",
model: {
id: "openai/whisper-large-v3-turbo",
pipeline_tag: "automatic-speech-recognition",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "conversational-llm-non-stream",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference", "together"],
opts: { streaming: false },
},
{
testName: "conversational-llm-stream",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference", "together"],
opts: { streaming: true },
},
{
testName: "conversational-vlm-non-stream",
task: "conversational",
model: {
id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
pipeline_tag: "image-text-to-text",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference", "fireworks-ai"],
opts: { streaming: false },
},
{
testName: "conversational-vlm-stream",
task: "conversational",
model: {
id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
pipeline_tag: "image-text-to-text",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference", "fireworks-ai"],
opts: { streaming: true },
},
{
testName: "conversational-llm-custom-endpoint",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference"],
opts: { endpointUrl: "http://localhost:8080/v1" },
},
{
testName: "document-question-answering",
task: "document-question-answering",
model: {
id: "impira/layoutlm-invoices",
pipeline_tag: "document-question-answering",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "image-classification",
task: "image-classification",
model: {
id: "Falconsai/nsfw_image_detection",
pipeline_tag: "image-classification",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "image-to-image",
task: "image-to-image",
model: {
id: "black-forest-labs/FLUX.1-Kontext-dev",
pipeline_tag: "image-to-image",
tags: [],
inference: "",
},
providers: ["fal-ai", "replicate", "hf-inference"],
},
{
testName: "tabular",
task: "tabular-classification",
model: {
id: "templates/tabular-classification",
pipeline_tag: "tabular-classification",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "text-to-audio-transformers",
task: "text-to-audio",
model: {
id: "facebook/musicgen-small",
pipeline_tag: "text-to-audio",
tags: ["transformers"],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "text-to-image",
task: "text-to-image",
model: {
id: "black-forest-labs/FLUX.1-schnell",
pipeline_tag: "text-to-image",
tags: [],
inference: "",
},
providers: ["hf-inference", "fal-ai"],
},
{
testName: "text-to-video",
task: "text-to-video",
model: {
id: "tencent/HunyuanVideo",
pipeline_tag: "text-to-video",
tags: [],
inference: "",
},
providers: ["replicate", "fal-ai"],
},
{
testName: "text-classification",
task: "text-classification",
model: {
id: "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
pipeline_tag: "text-classification",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "basic-snippet--token-classification",
task: "token-classification",
model: {
id: "FacebookAI/xlm-roberta-large-finetuned-conll03-english",
pipeline_tag: "token-classification",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "zero-shot-classification",
task: "zero-shot-classification",
model: {
id: "facebook/bart-large-mnli",
pipeline_tag: "zero-shot-classification",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "zero-shot-image-classification",
task: "zero-shot-image-classification",
model: {
id: "openai/clip-vit-large-patch14",
pipeline_tag: "zero-shot-image-classification",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "text-to-image--lora",
task: "text-to-image",
model: {
id: "openfree/flux-chatgpt-ghibli-lora",
pipeline_tag: "text-to-image",
tags: ["lora", "base_model:adapter:black-forest-labs/FLUX.1-dev", "base_model:black-forest-labs/FLUX.1-dev"],
inference: "",
},
lora: true,
providers: ["fal-ai"],
},
{
testName: "bill-to-param",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference"],
opts: { billTo: "huggingface" },
},
{
testName: "with-access-token",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference"],
opts: { accessToken: "hf_xxx" },
},
{
testName: "explicit-direct-request",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["together"],
opts: { directRequest: true },
},
{
testName: "text-to-speech",
task: "text-to-speech",
model: {
id: "nari-labs/Dia-1.6B",
pipeline_tag: "text-to-speech",
tags: [],
inference: "",
},
providers: ["fal-ai"],
},
{
testName: "feature-extraction",
task: "feature-extraction",
model: {
id: "intfloat/multilingual-e5-large-instruct",
pipeline_tag: "feature-extraction",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "question-answering",
task: "question-answering",
model: {
id: "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",
pipeline_tag: "question-answering",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
{
testName: "table-question-answering",
task: "table-question-answering",
model: {
id: "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",
pipeline_tag: "table-question-answering",
tags: [],
inference: "",
},
providers: ["hf-inference"],
},
] as const;
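
// Resolve the package root by walking up from this script's path (derived from
// import.meta.url) until a directory containing a package.json is found.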
const rootDirFinder = (): string => {
let currentPath = path.normalize(import.meta.url).replace("file:", "");
while (currentPath !== "/") {
if (pathExists(path.join(currentPath, "package.json"))) {
return currentPath;
}
currentPath = path.normalize(path.join(currentPath, ".."));
}
return "/";
};
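
// Fixtures for a given test live under "<package root>/snippets-fixtures/<test-name>".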
function getFixtureFolder(testName: string): string {
return path.join(rootDirFinder(), "snippets-fixtures", testName);
}
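
// Generate snippets for a model/provider/task combination via @huggingface/tasks,
// using a synthetic provider mapping (a placeholder providerId for third-party
// providers), then keep only the requested language, sorted by client name so the
// output is deterministic.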
function generateInferenceSnippet(
model: ModelDataMinimal,
language: Language,
provider: InferenceProviderOrPolicy,
task: WidgetType,
lora: boolean = false,
opts?: Record<string, unknown>
): InferenceSnippet[] {
const allSnippets = snippets.getInferenceSnippets(
model,
provider,
{
provider: provider,
hfModelId: model.id,
providerId: provider === "hf-inference" ? model.id : `<${provider} alias for ${model.id}>`,
status: "live",
task,
...(lora && task === "text-to-image"
? {
adapter: "lora",
adapterWeightsPath: `<path to LoRA weights in .safetensors format>`,
}
: {}),
},
opts
);
return allSnippets
		.filter((snippet) => snippet.language === language)
.sort((snippetA, snippetB) => snippetA.client.localeCompare(snippetB.client));
}
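
// Read the versioned fixtures for a test back into InferenceSnippet objects.
// Files are listed recursively under "<fixture>/<language>" and filtered on the
// ".<provider>." infix; the client name is the first path segment.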
async function getExpectedInferenceSnippet(
testName: string,
language: Language,
provider: SnippetInferenceProvider
): Promise<InferenceSnippet[]> {
const fixtureFolder = getFixtureFolder(testName);
const languageFolder = path.join(fixtureFolder, language);
if (!pathExists(languageFolder)) {
return [];
}
const files = await fs.readdir(languageFolder, { recursive: true });
const expectedSnippets: InferenceSnippet[] = [];
for (const file of files.filter((file) => file.includes(`.${provider}.`)).sort()) {
		const client = file.split("/")[0]; // e.g. "fal_client/1.fal-ai.py" => "fal_client"
const content = await fs.readFile(path.join(languageFolder, file), { encoding: "utf-8" });
expectedSnippets.push({ language, client, content });
}
return expectedSnippets;
}
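
// Write generated snippets to "<fixture>/<language>/<client>/<index>.<provider>.<ext>",
// keeping a per-client counter so that several snippets from the same client get
// increasing indices.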
async function saveExpectedInferenceSnippet(
testName: string,
language: Language,
provider: SnippetInferenceProvider,
snippets: InferenceSnippet[]
) {
const fixtureFolder = getFixtureFolder(testName);
await fs.mkdir(fixtureFolder, { recursive: true });
const indexPerClient = new Map<string, number>();
for (const snippet of snippets) {
const extension = EXTENSIONS[language];
const client = snippet.client;
const index = indexPerClient.get(client) ?? 0;
indexPerClient.set(client, index + 1);
const file = path.join(fixtureFolder, language, snippet.client, `${index}.${provider}.${extension}`);
await fs.mkdir(path.dirname(file), { recursive: true });
await fs.writeFile(file, snippet.content);
}
}
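
// In-source testing: import.meta.vitest is defined when this file is executed by
// vitest (typically enabled via vitest's `includeSource` option) and undefined when
// the file is run as a plain script.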
if (import.meta.vitest) {
	// Test mode ("pnpm test"): compare generated snippets against the expected fixtures.
const { describe, expect, it } = import.meta.vitest;
describe("inference API snippets", () => {
TEST_CASES.forEach(({ testName, task, model, providers, lora, opts }) => {
describe(testName, () => {
inferenceSnippetLanguages.forEach((language) => {
providers.forEach((provider) => {
it(language, async () => {
const generatedSnippets = generateInferenceSnippet(model, language, provider, task, lora, opts);
const expectedSnippets = await getExpectedInferenceSnippet(testName, language, provider);
expect(generatedSnippets).toEqual(expectedSnippets);
});
});
});
});
});
});
} else {
// Otherwise, generate the fixtures
console.log("✨ Re-generating snippets");
console.debug(" 🚜 Removing existing fixtures...");
await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });
console.debug(" 🏭 Generating new fixtures...");
	// Await each write sequentially: async callbacks inside forEach are fire-and-forget,
	// which would log "All done" before the fixtures are actually written and would
	// swallow write errors.
	for (const { testName, task, model, providers, lora, opts } of TEST_CASES) {
		console.debug(`  ${testName} (${providers.join(", ")})`);
		for (const language of inferenceSnippetLanguages) {
			for (const provider of providers) {
				const generatedSnippets = generateInferenceSnippet(model, language, provider, task, lora, opts);
				await saveExpectedInferenceSnippet(testName, language, provider, generatedSnippets);
			}
		}
	}
console.log("✅ All done!");
console.log("👉 Please check the generated fixtures before committing them.");
}