packages/gguf/scripts/generate-llm.ts (226 lines of code) (raw):

/**
 * Script for generating src/transformer-llm.ts
 * The source data is taken from llama.cpp (llama-arch.cpp + llama-model.cpp).
 */

import { writeFileSync } from "node:fs";

const SOURCE_CPP_URLS = [
	"https://raw.githubusercontent.com/ggml-org/llama.cpp/master/src/llama-arch.cpp",
	"https://raw.githubusercontent.com/ggml-org/llama.cpp/master/src/llama-model.cpp",
];

const DEST_FILE_PATH = "./src/transformer-llm.ts";

// Static header of the generated file.
// NOTE: `use_parallel_residual` must be boolean to agree with KV_TYPE.LLM_KV_USE_PARALLEL_RESIDUAL;
// declaring it as number made the per-arch intersection type `number & boolean` (= never).
const DEST_COMMON_SOURCE = `
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<
	Record<\`\${TArchitecture}.vocab_size\`, number>
	& Record<\`\${TArchitecture}.use_parallel_residual\`, boolean>
	& Record<\`\${TArchitecture}.tensor_data_layout\`, string>
>;

type Attention<TArchitecture extends string> = Record<
	\`\${TArchitecture}.attention.head_count\`,
	number
> & Partial<Record<
	\`\${TArchitecture}.attention.head_count_kv\`
	| \`\${TArchitecture}.attention.key_length\`
	| \`\${TArchitecture}.attention.value_length\`,
	number
>>;

export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
	Record<
		\`\${TArchitecture}.rope.dimension_count\`
		| \`\${TArchitecture}.rope.freq_base\`
		| \`\${TArchitecture}.rope.scale_linear\`
		| \`\${TArchitecture}.rope.scaling.factor\`
		| \`\${TArchitecture}.rope.scaling.original_context_length\`,
		number
	>
	& Record<\`\${TArchitecture}.rope.scaling.type\`, TransformerLLMRopeScalingType>
	& Record<\`\${TArchitecture}.rope.finetuned\`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
	Record<
		\`\${TArchitecture}.expert_count\`
		| \`\${TArchitecture}.expert_used_count\`,
		number
	>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture>
	& LLMBase<TArchitecture>
	& ModelBase<TArchitecture>
	& MOE<TArchitecture>
	& Attention<TArchitecture>
	& Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
	UNSPECIFIED = -1,
	NONE = 0,
	MEAN = 1,
	CLS = 2,
};
`;

/**
 * TypeScript type emitted for each llama.cpp `LLM_KV_*` constant read in `llama_model::load_hparams`.
 * Generation fails if a key read there has no entry here.
 */
const KV_TYPE: Record<string, string> = {
	LLM_KV_ATTENTION_LAYERNORM_RMS_EPS: "number",
	LLM_KV_ATTENTION_LAYERNORM_EPS: "number",
	LLM_KV_ATTENTION_CAUSAL: "boolean",
	LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT: "number",
	LLM_KV_POOLING_TYPE: "TransformerLLMPoolingType",
	LLM_KV_ATTENTION_CLAMP_KQV: "number",
	LLM_KV_ATTENTION_MAX_ALIBI_BIAS: "number",
	LLM_KV_SSM_CONV_KERNEL: "number",
	LLM_KV_SSM_INNER_SIZE: "number",
	LLM_KV_SSM_STATE_SIZE: "number",
	LLM_KV_SSM_TIME_STEP_RANK: "number",
	LLM_KV_LOGIT_SCALE: "number",
	LLM_KV_EXPERT_FEED_FORWARD_LENGTH: "number",
	LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH: "number",
	LLM_KV_ATTENTION_SLIDING_WINDOW: "number",
	LLM_KV_ATTN_LOGIT_SOFTCAPPING: "number",
	LLM_KV_FINAL_LOGIT_SOFTCAPPING: "number",
	LLM_KV_LEADING_DENSE_BLOCK_COUNT: "number",
	LLM_KV_ATTENTION_KV_LORA_RANK: "number",
	LLM_KV_EXPERT_SHARED_COUNT: "number",
	LLM_KV_EXPERT_WEIGHTS_SCALE: "number",
	LLM_KV_ROPE_SCALING_YARN_LOG_MUL: "number",
	LLM_KV_ROPE_DIMENSION_COUNT: "number",
	LLM_KV_ROPE_DIMENSION_SECTIONS: "number[]",
	LLM_KV_ATTENTION_Q_LORA_RANK: "number",
	LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT: "number",
	LLM_KV_DECODER_START_TOKEN_ID: "number",
	LLM_KV_USE_PARALLEL_RESIDUAL: "boolean",
	LLM_KV_WKV_HEAD_SIZE: "number",
	LLM_KV_TIME_MIX_EXTRA_DIM: "number",
	LLM_KV_TIME_DECAY_EXTRA_DIM: "number",
	// The next two are uint32 counts in GGUF (previously mistyped as boolean).
	LLM_KV_RESCALE_EVERY_N_LAYERS: "number",
	LLM_KV_TOKEN_SHIFT_COUNT: "number",
	LLM_KV_SWIN_NORM: "boolean",
	LLM_KV_ATTENTION_GROUPNORM_EPS: "number",
	LLM_KV_ATTENTION_GROUPNORM_GROUPS: "number",
	LLM_KV_ATTENTION_SCALE: "number",
	LLM_KV_EMBEDDING_SCALE: "number",
	LLM_KV_RESIDUAL_SCALE: "number",
	LLM_KV_SSM_DT_B_C_RMS: "boolean",
	LLM_KV_EXPERT_WEIGHTS_NORM: "boolean",
	// Stored as a uint32 enum value in GGUF (previously mistyped as boolean).
	LLM_KV_EXPERT_GATING_FUNC: "number",
};

interface Arch {
	cppConst: string; // for example: "LLM_ARCH_LLAMA"
	name: string; // for example: "llama"
	tsName: string; // for example: "ArchLlama"
	tensorNames: string[]; // for example: "token_embd"
	hparams: string[]; // LLM_KV_* constants read for this arch, in first-seen order, without duplicates
}

const RE_ARCH_NAME = /LLM_ARCH_[A-Z0-9_]+/;
const RE_CASE = new RegExp(`case (${RE_ARCH_NAME.source})`);
const RE_KV_NAME = /LLM_KV_[A-Z0-9_]+/;

/** Fetch every llama.cpp source file and concatenate them with newlines. */
async function fetchCppSource(urls: readonly string[]): Promise<string> {
	const sources = await Promise.all(
		urls.map(async (url) => {
			const res = await fetch(url);
			// Without this check an HTTP error page would be parsed as C++ and fail later with a misleading message.
			if (!res.ok) {
				throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
			}
			return await res.text();
		})
	);
	return sources.join("\n");
}

/** Return the lines of the first `<name> = ...;` initializer in the C++ source; throws if it is absent. */
function extractInitializerLines(cppSource: string, name: string): string[] {
	const body = cppSource.match(new RegExp(`${name} = (?<names>[^;]+)`))?.groups?.names;
	if (!body) {
		throw new Error(`${name} is empty`);
	}
	return body.split("\n");
}

/** Parse `LLM_ARCH_NAMES` into the list of architectures, skipping the "unknown" placeholder. */
function extractArchList(cppSource: string): Arch[] {
	const archList: Arch[] = [];
	for (const line of extractInitializerLines(cppSource, "LLM_ARCH_NAMES")) {
		const groups = line.match(/(?<cppConst>LLM_ARCH_[A-Z0-9_]+),\s+"(?<name>.+?)"/)?.groups;
		const cppConst = groups?.cppConst;
		const name = groups?.name;
		if (!cppConst || !name || /unknown/.test(name)) {
			continue;
		}
		archList.push({
			cppConst,
			name,
			tsName: snakeToPascal(cppConst.replace("LLM_", "")),
			tensorNames: [],
			hparams: [],
		});
	}
	return archList;
}

/**
 * Map each `LLM_KV_*` constant to its GGUF key template.
 * For example: LLM_KV_ATTENTION_LAYERNORM_RMS_EPS ==> "%s.attention.layer_norm_rms_epsilon"
 */
function extractKVNames(cppSource: string): Record<string, string> {
	const constToKVName: Record<string, string> = {};
	for (const line of extractInitializerLines(cppSource, "LLM_KV_NAMES")) {
		const groups = line.match(/(?<cppConst>LLM_KV_[A-Z0-9_]+)[,\s]+"(?<name>.+?)"/)?.groups;
		if (groups?.cppConst && groups.name) {
			constToKVName[groups.cppConst] = groups.name;
		}
	}
	return constToKVName;
}

/** Attach to each arch the tensor names listed under it in `LLM_TENSOR_NAMES` (collected, not emitted yet). */
function extractTensorNames(cppSource: string, archList: Arch[]): void {
	let currCppConst = "";
	for (const line of extractInitializerLines(cppSource, "LLM_TENSOR_NAMES")) {
		// a line naming LLM_ARCH_* starts that architecture's tensor section
		const cppConst = line.match(RE_ARCH_NAME)?.[0];
		if (cppConst) {
			currCppConst = cppConst;
			continue;
		}
		const tensorName = line.match(/LLM_TENSOR_[A-Z0-9_]+[,\s]+"(?<name>.+?)"/)?.groups?.name;
		if (tensorName) {
			archList.find((a) => a.cppConst === currCppConst)?.tensorNames.push(tensorName);
		}
	}
}

/**
 * Record, per arch, the `LLM_KV_*` keys mentioned under its `case LLM_ARCH_*:` label inside
 * `llama_model::load_hparams` (scanned until the first column-0 `}`).
 * Consecutive case labels share one body, so every label in such a run receives the keys.
 * Throws if the function cannot be found, rather than silently emitting no hparams.
 */
function extractHParams(cppSource: string, archList: Arch[]): void {
	let insideFn = false;
	let caseConsts: string[] = [];
	let inCaseLabelRun = false;

	for (const rawLine of cppSource.split("\n")) {
		const line = rawLine.trimEnd(); // tolerate CRLF line endings
		if (!insideFn) {
			insideFn = line.startsWith("void llama_model::load_hparams");
			if (!insideFn) {
				continue;
			}
		}

		const caseConst = line.match(RE_CASE)?.[1];
		if (caseConst) {
			caseConsts = inCaseLabelRun ? [...caseConsts, caseConst] : [caseConst];
			inCaseLabelRun = true;
			continue;
		}
		// blank lines and comments do not break a run of stacked case labels
		const trimmed = line.trim();
		if (trimmed !== "" && !trimmed.startsWith("//")) {
			inCaseLabelRun = false;
		}

		const keyConst = line.match(RE_KV_NAME)?.[0];
		if (keyConst) {
			for (const cppConst of caseConsts) {
				const arch = archList.find((a) => a.cppConst === cppConst);
				// a key read twice would otherwise emit a duplicate (invalid) property
				if (arch && !arch.hparams.includes(keyConst)) {
					arch.hparams.push(keyConst);
				}
			}
		}

		// end of load_hparams
		if (line === "}") {
			return;
		}
	}

	if (!insideFn) {
		throw new Error("llama_model::load_hparams not found in llama.cpp sources");
	}
}

/** Render the `export type Arch... = TransformerLLMBase<...>` declaration for one architecture. */
function renderArchType(arch: Arch, constToKVName: Record<string, string>): string {
	const head = `export type ${arch.tsName} = TransformerLLMBase<${JSON.stringify(arch.name)}>`;
	if (!arch.hparams.length) {
		return `${head};`;
	}
	const fields = arch.hparams.map((k) => {
		const tsType = KV_TYPE[k];
		if (!tsType) {
			throw new Error(`Cannot find type definition of ${k}`);
		}
		const kvName = constToKVName[k];
		if (!kvName) {
			throw new Error(`Cannot find GGUF key name of ${k} in LLM_KV_NAMES`);
		}
		return `\t${JSON.stringify(kvName.replace("%s", arch.name))}: ${tsType},`;
	});
	return [`${head} & {`, ...fields, "};"].join("\n");
}

/** Fetch llama.cpp sources, extract architectures and their hyperparameters, and write DEST_FILE_PATH. */
async function main(): Promise<void> {
	const cppSource = await fetchCppSource(SOURCE_CPP_URLS);

	const archList = extractArchList(cppSource);

	const constToKVName = extractKVNames(cppSource);
	console.log("constToKVName", constToKVName);

	// TODO: tensor names are collected but unused for now
	extractTensorNames(cppSource, archList);

	extractHParams(cppSource, archList);

	const content = [
		DEST_COMMON_SOURCE,
		"export const LLM_ARCHITECTURES = [",
		...archList.map((a) => `\t${JSON.stringify(a.name)},`),
		"] as const;",
		"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
		...archList.map((a) => renderArchType(a, constToKVName)),
		"",
		`export type TransformerLLM = ${archList.map((a) => a.tsName).join(" | ")};`,
	].join("\n");

	writeFileSync(DEST_FILE_PATH, content);
}

/** Convert SNAKE_CASE to PascalCase, e.g. "ARCH_LLAMA" -> "ArchLlama". */
function snakeToPascal(str: string): string {
	return str
		.split("_")
		.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
		.join("");
}

main().catch((err: unknown) => {
	console.error(err);
	process.exitCode = 1;
});