async function getSession()

in src/models.js [161:338]


/**
 * Resolves everything needed to create an ONNX Runtime inference session for one model file:
 * the device, the dtype, the model file itself (downloaded via `getModelFile`), any external
 * data files, and the session options.
 *
 * `device` and `dtype` are read from `options` first, then from `config["transformers.js_config"]`.
 * Either may be a single string or an object keyed by `fileName`. Entries found in
 * `transformers.js_config.device_config[<selected device>]` are merged over the custom config.
 *
 * @param {string} pretrained_model_name_or_path Model identifier/location, forwarded to `getModelFile`.
 * @param {string} fileName Model file name without the dtype suffix and without the `.onnx` extension.
 * @param {Object} options Loading options. Fields read here: `config`, `device`, `dtype`, `subfolder`,
 *   `session_options`, `use_external_data_format`. The whole object is also forwarded to `getModelFile`.
 * @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>}
 *   The model file contents (presumably a path instead when running in Node with `env.useFSCache`),
 *   the session options to hand to ONNX Runtime, and the resolved `{ dtype, kv_cache_dtype, device }`.
 *   Rejects with an Error in these cases:
 *   - the dtype is invalid;
 *   - fp16 is requested on WebGPU without fp16 support;
 *   - `kv_cache_dtype` is not float32/float16;
 *   - too many external data chunks are requested;
 *   - downloading the model file or any external data file fails.
 */
async function getSession(pretrained_model_name_or_path, fileName, options) {
    let custom_config = options.config?.['transformers.js_config'] ?? {};

    // `device` may be a single device name, or an object mapping file names to device names.
    let device = options.device ?? custom_config.device;
    if (device && typeof device !== 'string') {
        if (device.hasOwnProperty(fileName)) {
            device = device[fileName];
        } else {
            console.warn(`device not specified for "${fileName}". Using the default device.`);
            device = null;
        }
    }

    // If the device is not specified, we use the default (supported) execution providers.
    const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
        device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
    );

    const executionProviders = deviceToExecutionProviders(selectedDevice);

    // Update custom config with the selected device's config, if it exists
    const device_config = custom_config.device_config ?? {};
    if (device_config.hasOwnProperty(selectedDevice)) {
        custom_config = {
            ...custom_config,
            ...device_config[selectedDevice],
        };
    }

    // If options.dtype is specified, we use it to choose the suffix for the model file.
    // Otherwise, we use the default dtype for the device.
    // Like `device`, `dtype` may be a string or an object mapping file names to dtypes.
    let dtype = options.dtype ?? custom_config.dtype;
    if (typeof dtype !== 'string') {
        if (dtype && dtype.hasOwnProperty(fileName)) {
            dtype = dtype[fileName];
        } else {
            dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
            console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
        }
    }

    if (dtype === DATA_TYPES.auto) {
        // Try to choose the auto dtype based on the custom config
        let config_dtype = custom_config.dtype;
        if (typeof config_dtype !== 'string') {
            config_dtype = config_dtype?.[fileName];
        }

        if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) {
            // Defined by the config, and is not "auto"
            dtype = config_dtype;
        } else {
            // Choose default dtype based on device, falling back to fp32
            dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
        }
    }

    const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);

    if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
        throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`);
    } else if (selectedDtype === DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await isWebGpuFp16Supported())) {
        throw new Error(`The device (${selectedDevice}) does not support fp16.`);
    }

    // Only valid for models with a decoder
    const kv_cache_dtype_config = custom_config.kv_cache_dtype;
    const kv_cache_dtype = kv_cache_dtype_config
        ? (typeof kv_cache_dtype_config === 'string'
            ? kv_cache_dtype_config
            : kv_cache_dtype_config[selectedDtype] ?? 'float32')
        : undefined;

    if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
        throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
    }

    const session_config = {
        dtype: selectedDtype,
        kv_cache_dtype,
        device: selectedDevice,
    };

    // Construct the model file name: `${subfolder}/${fileName}${dtype suffix}.onnx`
    const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
    const baseName = `${fileName}${suffix}.onnx`;
    const modelFileName = `${options.subfolder ?? ''}/${baseName}`;

    // Shallow copy so the caller's `options.session_options` object itself is not mutated.
    const session_options = { ...options.session_options };

    // Overwrite `executionProviders` if not specified
    session_options.executionProviders ??= executionProviders;

    // Overwrite `freeDimensionOverrides` if specified in config and not set in session options
    const free_dimension_overrides = custom_config.free_dimension_overrides;
    if (free_dimension_overrides) {
        session_options.freeDimensionOverrides ??= free_dimension_overrides;
    } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
        console.warn(
            `WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` +
            `When 'free_dimension_overrides' is not set, you may experience significant performance degradation.`
        );
    }

    if (selectedDevice === 'webgpu') {
        const shapes = getKeyValueShapes(options.config, {
            prefix: 'present',
        });
        if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
            // Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
            /** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
            const preferredOutputLocation = {};
            for (const key in shapes) {
                preferredOutputLocation[key] = 'gpu-buffer';
            }
            session_options.preferredOutputLocation = preferredOutputLocation;
        }
    }

    // Handle onnx external data files.
    // The number of chunks is resolved and validated *before* any download starts, so an
    // invalid configuration throws without leaving an unobserved download promise behind.
    const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
    let num_chunks = 0;
    if (use_external_data_format) {
        let external_data_format;
        if (typeof use_external_data_format === 'object') {
            // Per-file setting: look up by full file name first, then by base file name.
            if (use_external_data_format.hasOwnProperty(baseName)) {
                external_data_format = use_external_data_format[baseName];
            } else if (use_external_data_format.hasOwnProperty(fileName)) {
                external_data_format = use_external_data_format[fileName];
            } else {
                external_data_format = false;
            }
        } else {
            external_data_format = use_external_data_format;
        }

        num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
        if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
            throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
        }
    }

    // Start downloading the model file; it is awaited together with the external data below.
    const return_path = apis.IS_NODE_ENV && env.useFSCache;
    const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, return_path);

    // A plain async function (rather than `new Promise(async (resolve) => ...)`) so that a
    // failed download rejects the returned promise instead of leaving it pending forever.
    const loadExternalDataChunk = async (path) => {
        const fullPath = `${options.subfolder ?? ''}/${path}`;
        const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, return_path);
        return data instanceof Uint8Array ? { path, data } : path;
    };

    /** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
    let externalDataPromises = [];
    if (use_external_data_format) {
        // Chunk 0 is `<baseName>_data`; chunk i > 0 is `<baseName>_data_<i>`.
        for (let i = 0; i < num_chunks; ++i) {
            const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
            externalDataPromises.push(loadExternalDataChunk(path));
        }
    } else if (session_options.externalData !== undefined) {
        externalDataPromises = session_options.externalData.map(async (ext) => {
            // if the external data is a string, fetch the file and replace the string with its content
            // @ts-expect-error TS2339
            if (typeof ext.data === "string") {
                // @ts-expect-error TS2339
                const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options);
                // @ts-expect-error TS2698
                return { ...ext, data: ext_buffer };
            }
            return ext;
        });
    }

    // Await the model file and all external data together: a failure in any download rejects
    // this call, and no rejected promise is left without a handler.
    const [buffer_or_path, externalData] = await Promise.all([
        bufferOrPathPromise,
        Promise.all(externalDataPromises),
    ]);

    // Outside Node, the external data is passed to ONNX Runtime through the session options.
    if (externalData.length > 0 && !apis.IS_NODE_ENV) {
        session_options.externalData = externalData;
    }

    return { buffer_or_path, session_options, session_config };
}