export function getKeyValueShapes()

in src/configs.js [262:335]


/**
 * Build the shapes of the past key/value cache entries for a model.
 *
 * Only `config.normalized_config` is read. The layout depends on the architecture:
 *  - encoder-decoder (when both `num_encoder_heads` and `num_decoder_heads` are present):
 *    `<prefix>.<i>.encoder.{key,value}` and `<prefix>.<i>.decoder.{key,value}`,
 *    each `[batch_size, heads, 0, dim_kv]`, for `num_decoder_layers` layers;
 *  - `falcon`: `<prefix>.<i>.{key,value}` as `[batch_size * num_heads, 0, dim_kv]`;
 *  - `multi_query` (e.g. `gpt_bigcode`): a single fused `<prefix>.<i>.key_value`
 *    as `[batch_size * num_heads, 0, 2 * dim_kv]`;
 *  - `bloom`: key as `[batch_size * num_heads, dim_kv, 0]`,
 *    value as `[batch_size * num_heads, 0, dim_kv]`;
 *  - `openelm`: `num_heads` is indexed per layer, `[batch_size, num_heads[i], 0, dim_kv]`;
 *  - anything else: `[batch_size, num_heads, 0, dim_kv]`.
 *
 * The `0` entry is the past-sequence-length axis (empty cache).
 *
 * Every entry of the returned record is a distinct array, so mutating one shape
 * never affects another.
 *
 * @param {Object} config Model configuration exposing `normalized_config`.
 * @param {Object} [options]
 * @param {string} [options.prefix='past_key_values'] Prefix used for every entry name.
 * @param {number} [options.batch_size=1] Batch size used for the leading dimension.
 * @returns {Record<string, number[]>} Map from entry name to its shape.
 */
export function getKeyValueShapes(config, {
    prefix = 'past_key_values',
    batch_size = 1,
} = {}) {
    /** @type {Record<string, number[]>} */
    const decoderFeeds = {};
    const normalized_config = config.normalized_config;

    /**
     * Register one cache entry. The shape is copied so that no two entries
     * alias the same array.
     * @param {number} layer Layer index.
     * @param {string} suffix Entry name after the layer index (e.g. `key`).
     * @param {number[]} dims Shape to store.
     */
    const setShape = (layer, suffix, dims) => {
        decoderFeeds[`${prefix}.${layer}.${suffix}`] = [...dims];
    };

    if (normalized_config.is_encoder_decoder && (
        'num_encoder_heads' in normalized_config && 'num_decoder_heads' in normalized_config
    )) {
        // Per-head size: explicit value if configured, otherwise hidden size split across heads.
        const encoder_dim_kv = normalized_config.encoder_dim_kv ?? (
            normalized_config.encoder_hidden_size / normalized_config.num_encoder_heads
        );
        const decoder_dim_kv = normalized_config.decoder_dim_kv ?? (
            normalized_config.decoder_hidden_size / normalized_config.num_decoder_heads
        );

        const encoder_dims = [batch_size, normalized_config.num_encoder_heads, 0, encoder_dim_kv];
        const decoder_dims = [batch_size, normalized_config.num_decoder_heads, 0, decoder_dim_kv];
        for (let i = 0; i < normalized_config.num_decoder_layers; ++i) {
            setShape(i, 'encoder.key', encoder_dims);
            setShape(i, 'encoder.value', encoder_dims);
            setShape(i, 'decoder.key', decoder_dims);
            setShape(i, 'decoder.value', decoder_dims);
        }
        return decoderFeeds;
    }

    // Decoders (also reached by encoder-decoder configs lacking separate head counts).
    const num_heads = normalized_config.num_heads;
    const num_layers = normalized_config.num_layers;
    const dim_kv = normalized_config.dim_kv ?? (
        normalized_config.hidden_size /
        (normalized_config.num_attention_heads ?? num_heads)
    );

    if (normalized_config.model_type === 'falcon') {
        // NOTE: Custom implementation for Falcon (batch and heads are fused).
        const dims = [batch_size * num_heads, 0, dim_kv];
        for (let i = 0; i < num_layers; ++i) {
            setShape(i, 'key', dims);
            setShape(i, 'value', dims);
        }
    } else if (normalized_config.multi_query) { // e.g., for `gpt_bigcode`
        // Key and value live in one fused entry, hence `2 * dim_kv`.
        const dims = [batch_size * num_heads, 0, 2 * dim_kv];
        for (let i = 0; i < num_layers; ++i) {
            setShape(i, 'key_value', dims);
        }
    } else if (normalized_config.model_type === 'bloom') {
        // NOTE: Custom implementation for Bloom (key and value use different axis orders).
        const keyDims = [batch_size * num_heads, dim_kv, 0]; // [batch_size x num_heads, dim_kv, past_sequence_length]
        const valueDims = [batch_size * num_heads, 0, dim_kv]; // [batch_size x num_heads, past_sequence_length, dim_kv]
        for (let i = 0; i < num_layers; ++i) {
            setShape(i, 'key', keyDims);
            setShape(i, 'value', valueDims);
        }
    } else if (normalized_config.model_type === 'openelm') {
        // `num_heads` is indexed per layer here, so the shape is rebuilt for each one.
        for (let i = 0; i < num_layers; ++i) {
            const dims = [batch_size, num_heads[i], 0, dim_kv];
            setShape(i, 'key', dims);
            setShape(i, 'value', dims);
        }
    } else { // Decoder-only
        const dims = [batch_size, num_heads, 0, dim_kv];
        for (let i = 0; i < num_layers; ++i) {
            setShape(i, 'key', dims);
            setShape(i, 'value', dims);
        }
    }

    return decoderFeeds;
}