in src/models.js [1116:1289]
static async from_pretrained(pretrained_model_name_or_path, {
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
subfolder = 'onnx',
device = null,
dtype = null,
use_external_data_format = null,
session_options = {},
} = {}) {
let options = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
subfolder,
device,
dtype,
use_external_data_format,
session_options,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'vision_encoder',
prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
}, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
]);
} else if (modelType === MODEL_TYPES.ImageTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}
if (config.is_encoder_decoder) {
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.AudioTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
audio_encoder: 'audio_encoder',
decoder_model_merged: 'decoder_model_merged',
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.ImageAudioTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
audio_encoder: 'audio_encoder',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Musicgen) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MultiModality) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'language_model',
lm_head: 'lm_head',
gen_head: 'gen_head',
gen_img_embeds: 'gen_img_embeds',
image_decode: 'image_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Phi3V) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'model',
vision_encoder: 'vision_encoder',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.AutoEncoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
encoder_model: 'encoder_model',
decoder_model: 'decoder_model',
}, options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
const type = modelName ?? config?.model_type;
if (type !== 'custom') {
console.warn(`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
}
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
]);
}
// @ts-ignore
return new this(config, ...info);
}