in src/models.js [724:799]
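
/**
 * Generic forward pass shared by multimodal text-generation models: it computes
 * the text embeddings, optionally merges them with the encoded modality features
 * (audio or image), and then runs the decoder. The modality-specific pieces are
 * injected by the caller via `encode_function`, `merge_function`,
 * `modality_input_name`, and `modality_output_name`.
 */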
async function genericTextToTextForward(self, {
    // Generic parameters:
    encode_function,
    merge_function,
    modality_input_name,
    modality_output_name,

    // Produced by the tokenizer/processor:
    input_ids = null,
    attention_mask = null,

    // Used during generation:
    position_ids = null,
    inputs_embeds = null,
    past_key_values = null,

    // Generic generation parameters
    generation_config = null,
    logits_processor = null,

    // Additional parameters
    ...kwargs
}) {
    const modality_values = kwargs[modality_input_name];

    if (!inputs_embeds) {
        // 1. Extract the text embeddings.
        inputs_embeds = await self.encode_text({ input_ids, ...kwargs });

        // 2. Possibly, merge text and modality values.
        if (modality_values && input_ids.dims[1] !== 1) {
            const modality_features = await encode_function({
                // Pass the modality values under their expected key.
                // The caller knows whether this is audio or image.
                [modality_input_name]: modality_values,
                ...kwargs
            });
            ({ inputs_embeds, attention_mask } = merge_function({
                [modality_output_name]: modality_features,
                inputs_embeds,
                input_ids,
                attention_mask,
            }));

        } else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
            // This branch handles decoding with a KV cache: only the newly generated
            // token id is passed in, so the attention mask must be extended to cover
            // the tokens already stored in the cache.
            const target_length = input_ids.dims[1]; // always 1
            const past_length = Object.values(past_key_values)[0].dims.at(-2);

            attention_mask = cat([
                ones([input_ids.dims[0], past_length]),
                attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
            ], 1);
        }
    }

    if (!position_ids) {
        if (self.config.model_type === 'qwen2_vl') {
            // Special case for qwen2_vl models
            // @ts-ignore
            const { image_grid_thw, video_grid_thw } = kwargs;
            [position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask);
        }
    }

    // 3. Call the decoder forward using the updated inputs.
    const outputs = await decoderForward(self, {
        inputs_embeds,
        past_key_values,
        attention_mask,
        position_ids,
        generation_config,
        logits_processor,
    }, true);
    return outputs;
}
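
For context, a minimal sketch of how a model class might specialize this helper for the image case follows. The wrapper name `imageTextToTextForward` and the members `encode_image` and `_merge_input_ids_with_image_features` are illustrative assumptions, not taken from the excerpt above; only `genericTextToTextForward` and its parameters come from the source.

// Hypothetical wrapper (names are assumptions, see note above): an
// image-text-to-text model delegates its forward pass by binding its own
// encoder and merge logic into the generic helper.
async function imageTextToTextForward(self, params) {
    return genericTextToTextForward(self, {
        ...params,
        // Key under which the processor provides the raw modality values.
        modality_input_name: 'pixel_values',
        // Key under which the encoded features are passed to the merge step.
        modality_output_name: 'image_features',
        // Assumed model method that turns pixel_values into image embeddings.
        encode_function: self.encode_image.bind(self),
        // Assumed model method that splices the image embeddings into the text
        // embedding sequence and expands the attention mask to match.
        merge_function: self._merge_input_ids_with_image_features.bind(self),
    });
}

Routing everything through one generic function keeps the KV-cache handling and the qwen2_vl position-id special case in a single place, so each model class only has to supply its encoder and merge step.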