static async from_pretrained()

in src/models.js [1116:1289]


    static async from_pretrained(pretrained_model_name_or_path, {
        progress_callback = null,
        config = null,
        cache_dir = null,
        local_files_only = false,
        revision = 'main',
        model_file_name = null,
        subfolder = 'onnx',
        device = null,
        dtype = null,
        use_external_data_format = null,
        session_options = {},
    } = {}) {

        let options = {
            progress_callback,
            config,
            cache_dir,
            local_files_only,
            revision,
            model_file_name,
            subfolder,
            device,
            dtype,
            use_external_data_format,
            session_options,
        }

        const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
        const modelType = MODEL_TYPE_MAPPING.get(modelName);

        config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

        let info;
        if (modelType === MODEL_TYPES.DecoderOnly) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: options.model_file_name ?? 'model',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'encoder_model',
                    decoder_model_merged: 'decoder_model_merged',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.MaskGeneration) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'vision_encoder',
                    prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.EncoderDecoder) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'encoder_model',
                    decoder_model_merged: 'decoder_model_merged',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.ImageTextToText) {
            const sessions = {
                embed_tokens: 'embed_tokens',
                vision_encoder: 'vision_encoder',
                decoder_model_merged: 'decoder_model_merged',
            }
            if (config.is_encoder_decoder) {
                sessions['model'] = 'encoder_model';
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, sessions, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.AudioTextToText) {
            const sessions = {
                embed_tokens: 'embed_tokens',
                audio_encoder: 'audio_encoder',
                decoder_model_merged: 'decoder_model_merged',
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, sessions, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);
        } else if (modelType === MODEL_TYPES.ImageAudioTextToText) {
            const sessions = {
                embed_tokens: 'embed_tokens',
                audio_encoder: 'audio_encoder',
                vision_encoder: 'vision_encoder',
                decoder_model_merged: 'decoder_model_merged',
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, sessions, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);
        } else if (modelType === MODEL_TYPES.Musicgen) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: 'text_encoder',
                    decoder_model_merged: 'decoder_model_merged',
                    encodec_decode: 'encodec_decode',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.MultiModality) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    prepare_inputs_embeds: 'prepare_inputs_embeds',
                    model: 'language_model',
                    lm_head: 'lm_head',
                    gen_head: 'gen_head',
                    gen_img_embeds: 'gen_img_embeds',
                    image_decode: 'image_decode',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);

        } else if (modelType === MODEL_TYPES.Phi3V) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    prepare_inputs_embeds: 'prepare_inputs_embeds',
                    model: 'model',
                    vision_encoder: 'vision_encoder',
                }, options),
                getOptionalConfigs(pretrained_model_name_or_path, {
                    generation_config: 'generation_config.json',
                }, options),
            ]);
        } else if (modelType === MODEL_TYPES.AutoEncoder) {
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    encoder_model: 'encoder_model',
                    decoder_model: 'decoder_model',
                }, options),
            ]);
        } else { // should be MODEL_TYPES.EncoderOnly
            if (modelType !== MODEL_TYPES.EncoderOnly) {
                const type = modelName ?? config?.model_type;
                if (type !== 'custom') {
                    console.warn(`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
                }
            }
            info = await Promise.all([
                constructSessions(pretrained_model_name_or_path, {
                    model: options.model_file_name ?? 'model',
                }, options),
            ]);
        }

        // @ts-ignore
        return new this(config, ...info);
    }