async function genericTextToTextForward()

in src/models.js [724:799]
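
Shared forward pass for multimodal text-generation models (e.g. image-text-to-text or audio-text-to-text). The caller injects an encode_function for the extra modality and a merge_function that splices the resulting features into the text embeddings; caching, position ids, and the final decoder call are handled generically here. The body relies on the cat and ones tensor helpers and the module-level decoderForward, all in scope elsewhere in src/models.js.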


async function genericTextToTextForward(self, {
    // Generic parameters:
    encode_function,
    merge_function,
    modality_input_name,
    modality_output_name,

    // Produced by the tokenizer/processor:
    input_ids = null,
    attention_mask = null,

    // Used during generation:
    position_ids = null,
    inputs_embeds = null,
    past_key_values = null,

    // Generic generation parameters
    generation_config = null,
    logits_processor = null,

    // Additional parameters
    ...kwargs
}) {
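    // For an image-text model the caller would typically pass
    // modality_input_name = 'pixel_values' and modality_output_name =
    // 'image_features' (illustrative names; see the wrapper sketch after
    // this function); an audio-text model might use 'audio_values' /
    // 'audio_features'.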
    const modality_values = kwargs[modality_input_name];
    if (!inputs_embeds) {
        // 1. Extract the text embeddings by mapping input_ids through the
        //    model's input embedding layer.
        inputs_embeds = await self.encode_text({ input_ids, ...kwargs });

        // 2. If this is the prefill pass (sequence length > 1), merge the
        //    text and modality embeddings.
        if (modality_values && input_ids.dims[1] !== 1) {
            const modality_features = await encode_function({
                // Pass the modality values under the key the encoder expects;
                // the caller knows whether this is audio or image data.
                [modality_input_name]: modality_values,
                ...kwargs
            });
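            // Typically, merge_function scatters the modality features into
            // inputs_embeds at the positions of the modality placeholder tokens
            // in input_ids, and resizes attention_mask to the merged length.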
            ({ inputs_embeds, attention_mask } = merge_function({
                [modality_output_name]: modality_features,
                inputs_embeds,
                input_ids,
                attention_mask,
            }));

        } else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
            // This branch handles incremental decoding with a KV cache: only the
            // newest token is fed in, so the attention mask must be rebuilt to
            // cover every cached position plus the new one.
            const target_length = input_ids.dims[1]; // always 1
            const past_length = Object.values(past_key_values)[0].dims.at(-2);

            attention_mask = cat([
                ones([input_ids.dims[0], past_length]),
                attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
            ], 1);
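            // Worked example (assumed shapes): with batch size 1, past_length 42
            // and target_length 1, this is ones([1, 42]) concatenated with the
            // old mask's last column, giving a new mask of shape [1, 43].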
        }
    }

    if (!position_ids) {
        if (self.config.model_type === 'qwen2_vl') {
            // Special case for Qwen2-VL, whose multimodal rotary position
            // embeddings (M-RoPE) derive position ids from the image/video
            // grid dimensions rather than from simple sequence offsets.
            // @ts-ignore
            const { image_grid_thw, video_grid_thw } = kwargs;
            [position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask);
        }
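        // For other model types, position_ids stays null here; decoderForward
        // can derive default position ids from the attention mask when the
        // underlying session requires them.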
    }

    // 3. Run the decoder forward pass with the updated inputs.
    const outputs = await decoderForward(self, {
        inputs_embeds,
        past_key_values,
        attention_mask,
        position_ids,
        generation_config,
        logits_processor,
    }, true);
    return outputs;
}
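
A minimal sketch of how a per-modality forward can specialize this helper. The wrapper name and the encode_image / _merge_input_ids_with_image_features methods are illustrative assumptions, not confirmed API:

async function visionTextToTextForward(self, params) {
    return genericTextToTextForward(self, {
        ...params,
        // Key produced by the image processor, and the key under which the
        // merge step expects the encoded features (assumed names).
        modality_input_name: 'pixel_values',
        modality_output_name: 'image_features',
        // Project pixel values into the text embedding space
        // (vision tower + multimodal projector).
        encode_function: (inputs) => self.encode_image(inputs),
        // Splice the image features into the text embeddings at the
        // placeholder (<image>) token positions.
        merge_function: (inputs) => self._merge_input_ids_with_image_features(inputs),
    });
}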