packages/gguf/scripts/generate-llm.ts
/**
 * Script for generating transformer-llm.ts
 * The source data is taken from llama.cpp
 */
import { writeFileSync } from "node:fs";
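// Upstream llama.cpp sources to scrape: llama-arch.cpp defines the architecture,
// KV-name and tensor-name tables; llama-model.cpp contains llama_model::load_hparams().
// Both point at `master`, so the regenerated output tracks upstream changes.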
const SOURCE_CPP_URLS = [
	"https://raw.githubusercontent.com/ggml-org/llama.cpp/master/src/llama-arch.cpp",
	"https://raw.githubusercontent.com/ggml-org/llama.cpp/master/src/llama-model.cpp",
];
const DEST_FILE_PATH = "./src/transformer-llm.ts";
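// Hand-written preamble emitted verbatim at the top of the generated file.
// Backticks and \${...} are escaped because this string is itself a template literal.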
const DEST_COMMON_SOURCE = `
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<Record<
	\`\${TArchitecture}.vocab_size\`
	| \`\${TArchitecture}.use_parallel_residual\`
	| \`\${TArchitecture}.tensor_data_layout\`,
	number
>>;

type Attention<TArchitecture extends string> = Record<
	\`\${TArchitecture}.attention.head_count\`,
	number
> & Partial<Record<
	\`\${TArchitecture}.attention.head_count_kv\`
	| \`\${TArchitecture}.attention.key_length\`
	| \`\${TArchitecture}.attention.value_length\`,
	number
>>;

export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";

type Rope<TArchitecture extends LLMArchitecture> = Partial<
	Record<
		\`\${TArchitecture}.rope.dimension_count\`
		| \`\${TArchitecture}.rope.freq_base\`
		| \`\${TArchitecture}.rope.scale_linear\`
		| \`\${TArchitecture}.rope.scaling.factor\`
		| \`\${TArchitecture}.rope.scaling.original_context_length\`,
		number
	>
	& Record<\`\${TArchitecture}.rope.scaling.type\`, TransformerLLMRopeScalingType>
	& Record<\`\${TArchitecture}.rope.finetuned\`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
	Record<
		\`\${TArchitecture}.expert_count\`
		| \`\${TArchitecture}.expert_used_count\`,
		number
	>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture>
	& LLMBase<TArchitecture>
	& ModelBase<TArchitecture>
	& MOE<TArchitecture>
	& Attention<TArchitecture>
	& Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
	UNSPECIFIED = -1,
	NONE = 0,
	MEAN = 1,
	CLS = 2,
}
`;
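// Maps each llama.cpp LLM_KV_* constant that may appear in load_hparams() to the
// TypeScript type of the corresponding generated field. The generator throws when it
// encounters a constant missing from this map, so new upstream hparams fail loudly.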
const KV_TYPE: Record<string, string> = {
	LLM_KV_ATTENTION_LAYERNORM_RMS_EPS: "number",
	LLM_KV_ATTENTION_LAYERNORM_EPS: "number",
	LLM_KV_ATTENTION_CAUSAL: "boolean",
	LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT: "number",
	LLM_KV_POOLING_TYPE: "TransformerLLMPoolingType",
	LLM_KV_ATTENTION_CLAMP_KQV: "number",
	LLM_KV_ATTENTION_MAX_ALIBI_BIAS: "number",
	LLM_KV_SSM_CONV_KERNEL: "number",
	LLM_KV_SSM_INNER_SIZE: "number",
	LLM_KV_SSM_STATE_SIZE: "number",
	LLM_KV_SSM_TIME_STEP_RANK: "number",
	LLM_KV_LOGIT_SCALE: "number",
	LLM_KV_EXPERT_FEED_FORWARD_LENGTH: "number",
	LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH: "number",
	LLM_KV_ATTENTION_SLIDING_WINDOW: "number",
	LLM_KV_ATTN_LOGIT_SOFTCAPPING: "number",
	LLM_KV_FINAL_LOGIT_SOFTCAPPING: "number",
	LLM_KV_LEADING_DENSE_BLOCK_COUNT: "number",
	LLM_KV_ATTENTION_KV_LORA_RANK: "number",
	LLM_KV_EXPERT_SHARED_COUNT: "number",
	LLM_KV_EXPERT_WEIGHTS_SCALE: "number",
	LLM_KV_ROPE_SCALING_YARN_LOG_MUL: "number",
	LLM_KV_ROPE_DIMENSION_COUNT: "number",
	LLM_KV_ROPE_DIMENSION_SECTIONS: "number[]",
	LLM_KV_ATTENTION_Q_LORA_RANK: "number",
	LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT: "number",
	LLM_KV_DECODER_START_TOKEN_ID: "number",
	LLM_KV_USE_PARALLEL_RESIDUAL: "boolean",
	LLM_KV_WKV_HEAD_SIZE: "number",
	LLM_KV_TIME_MIX_EXTRA_DIM: "number",
	LLM_KV_TIME_DECAY_EXTRA_DIM: "number",
	LLM_KV_RESCALE_EVERY_N_LAYERS: "number", // uint32 in llama.cpp hparams, not a flag
	LLM_KV_TOKEN_SHIFT_COUNT: "number", // uint32 in llama.cpp hparams, not a flag
	LLM_KV_SWIN_NORM: "boolean",
	LLM_KV_ATTENTION_GROUPNORM_EPS: "number",
	LLM_KV_ATTENTION_GROUPNORM_GROUPS: "number",
	LLM_KV_ATTENTION_SCALE: "number",
	LLM_KV_EMBEDDING_SCALE: "number",
	LLM_KV_RESIDUAL_SCALE: "number",
	LLM_KV_SSM_DT_B_C_RMS: "boolean",
	LLM_KV_EXPERT_WEIGHTS_NORM: "boolean",
	LLM_KV_EXPERT_GATING_FUNC: "number", // enum (integer) in llama.cpp, not a flag
};
interface Arch {
	cppConst: string; // for example: "LLM_ARCH_LLAMA"
	name: string; // for example: "llama"
	tsName: string; // for example: "ArchLlama"
	tensorNames: string[]; // for example: ["token_embd"]
	hparams: string[]; // LLM_KV_* constants used by this architecture
}
async function main() {
	const cppSources = await Promise.all(
		SOURCE_CPP_URLS.map(async (url) => {
			const res = await fetch(url);
			return await res.text();
		})
	);
	const cppSource = cppSources.join("\n");
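
	// The regexes below rely on the (roughly stable) shape of the llama.cpp tables,
	// which look approximately like this (illustrative, not verbatim):
	//   static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
	//       { LLM_ARCH_LLAMA, "llama" },
	//       ...
	//   };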

	/////////////////////////////////////
	// extract list of all architectures

	const archList: Arch[] = [];
	const RE_ARCH_NAME = /LLM_ARCH_[A-Z0-9_]+/;
	const matchedArchList = cppSource.match(/LLM_ARCH_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
	if (!matchedArchList?.length) {
		throw new Error("LLM_ARCH_NAMES is empty");
	}
	for (const line of matchedArchList) {
		const matched = line.match(/(?<cppConst>LLM_ARCH_[A-Z0-9_]+),\s+"(?<name>.+?)"/);
		if (matched?.groups && !matched.groups.name.match(/unknown/)) {
			archList.push({
				cppConst: matched.groups.cppConst,
				name: matched.groups.name,
				tsName: snakeToPascal(matched.groups.cppConst.replace("LLM_", "")),
				tensorNames: [],
				hparams: [],
			});
		}
	}
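	// at this point, archList contains entries like:
	//   { cppConst: "LLM_ARCH_LLAMA", name: "llama", tsName: "ArchLlama", tensorNames: [], hparams: [] }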

	/////////////////////////////////////
	// extract mapping from constant names to KV names
	// for example: LLM_KV_ATTENTION_LAYERNORM_RMS_EPS ==> "%s.attention.layer_norm_rms_epsilon"

	const constToKVName: { [cppConst: string]: string } = {};
	const matchedKVList = cppSource.match(/LLM_KV_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
	if (!matchedKVList?.length) {
		throw new Error("LLM_KV_NAMES is empty");
	}
	for (const line of matchedKVList) {
		const matched = line.match(/(?<cppConst>LLM_KV_[A-Z0-9_]+)[,\s]+"(?<name>.+?)"/);
		if (matched?.groups) {
			constToKVName[matched.groups.cppConst] = matched.groups.name;
		}
	}
	console.log("constToKVName", constToKVName);

	/////////////////////////////////////
	// extract list of tensor names for each architecture
	// TODO: unused for now

	const matchedTensorList = cppSource.match(/LLM_TENSOR_NAMES = (?<names>[^;]+)/)?.groups?.names.split("\n");
	if (!matchedTensorList?.length) {
		throw new Error("LLM_TENSOR_NAMES is empty");
	}
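	// tensor entries are grouped under their architecture in the C++ table, so track the
	// most recently seen LLM_ARCH_* constant and attribute subsequent LLM_TENSOR_* lines to it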
	let currCppConst = "";
	for (const line of matchedTensorList) {
		// check if current line has LLM_ARCH_*
		const cppConst = line.match(RE_ARCH_NAME)?.[0];
		if (cppConst) {
			currCppConst = cppConst;
			continue;
		}
		// check if current line has LLM_TENSOR_*
		const tensorMatched = line.match(/LLM_TENSOR_[A-Z0-9_]+[,\s]+"(?<name>.+?)"/);
		if (tensorMatched?.groups) {
			const arch = archList.find((a) => a.cppConst === currCppConst);
			if (arch) arch.tensorNames.push(tensorMatched.groups.name);
		}
	}

	/////////////////////////////////////
	// extract list of hyperparams for each architecture

	let insideLoadHParamsFn = false;
	currCppConst = "";
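	// scan llama_model::load_hparams() line by line: a "case LLM_ARCH_*" label opens a new
	// architecture scope, and any LLM_KV_* constant appearing on a following line (typically
	// inside a get_key(...) call) is recorded as an hparam of that architecture; the first
	// unindented "}" marks the end of the function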
	const RE_CASE = new RegExp(`case (${RE_ARCH_NAME.source})`);
	for (const line of cppSource.split("\n")) {
		// check if current line is function llama_model::load_hparams()
		if (line.startsWith("void llama_model::load_hparams")) {
			insideLoadHParamsFn = true;
		}
		if (!insideLoadHParamsFn) {
			continue;
		}
		// check if current line has LLM_ARCH_*
		const cppConst = line.match(RE_CASE)?.[1];
		if (cppConst) {
			currCppConst = cppConst;
			continue;
		}
		// check if current line has get_key(...)
		const keyConst = line.match(/LLM_KV_[A-Z0-9_]+/)?.[0];
		if (keyConst) {
			const arch = archList.find((a) => a.cppConst === currCppConst);
			if (arch) {
				arch.hparams.push(keyConst);
			}
		}
		// check if current line is end-of-function
		if (line === "}") {
			break;
		}
	}
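	// e.g. the "llama" arch typically ends up with hparams like
	// ["LLM_KV_ATTENTION_LAYERNORM_RMS_EPS", ...] (the exact list depends on upstream)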

	/////////////////////////////////////
	// write result to file
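	// the generated file has roughly this shape (illustrative):
	//   export const LLM_ARCHITECTURES = ["llama", ...] as const;
	//   type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];
	//   export type ArchLlama = TransformerLLMBase<"llama"> & {
	//   	"llama.attention.layer_norm_rms_epsilon": number,
	//   	...
	//   };
	//   export type TransformerLLM = ArchLlama | ... ;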
	const content = [
		DEST_COMMON_SOURCE,
		"export const LLM_ARCHITECTURES = [",
		...archList.map((a) => `\t${JSON.stringify(a.name)},`),
		"] as const;",
		"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
		...archList.map((a) => {
			let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
			if (a.hparams.length) {
				code += [
					" & {",
					...a.hparams.map((k) => {
						if (!KV_TYPE[k]) {
							throw new Error(`Cannot find type definition of ${k}`);
						}
						return `\t${JSON.stringify(constToKVName[k].replace("%s", a.name))}: ${KV_TYPE[k]},`;
					}),
					"};",
				].join("\n");
			} else {
				code += ";";
			}
			return code;
		}),
		"",
		`export type TransformerLLM = ${archList.map((a) => a.tsName).join(" | ")};`,
	].join("\n");
	writeFileSync(DEST_FILE_PATH, content);
}
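
// for example: snakeToPascal("ARCH_LLAMA") === "ArchLlama"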
function snakeToPascal(str: string) {
	return str
		.split("_")
		.map((word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase())
		.join("");
}
main();