async _prepare_encoder_decoder_kwargs_for_generation()

in src/models.js [1640:1688]


    async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name, generation_config }) {
        if (
            this.sessions['model'].inputNames.includes('inputs_embeds')
            && !model_inputs.inputs_embeds
            && '_prepare_inputs_embeds' in this
        ) {
            // Encoder expects `inputs_embeds` instead of `input_ids`
            const { input_ids, pixel_values, attention_mask, ...kwargs } = model_inputs;
            // @ts-ignore
            const prepared_inputs = await this._prepare_inputs_embeds(model_inputs);
            model_inputs = {
                ...kwargs,
                ...pick(prepared_inputs, ['inputs_embeds', 'attention_mask']),
            };
        }
        let { last_hidden_state } = await encoderForward(this, model_inputs);

        // for classifier free guidance we need to add a 'null' input to our encoder hidden states
        if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {

            last_hidden_state = cat([
                last_hidden_state,
                full_like(last_hidden_state, 0.0),
            ], 0);

            if ('attention_mask' in model_inputs) {
                model_inputs['attention_mask'] = cat([
                    model_inputs['attention_mask'],
                    zeros_like(model_inputs['attention_mask']),
                ], 0);
            }

        } else if (model_inputs.decoder_input_ids) {
            // Ensure that the encoder outputs have the same batch size as the decoder inputs,
            // allowing for more efficient batched generation for single inputs
            const decoder_input_ids_batch_size = toI64Tensor(model_inputs.decoder_input_ids).dims[0];
            if (decoder_input_ids_batch_size !== last_hidden_state.dims[0]) {
                if (last_hidden_state.dims[0] !== 1) {
                    throw new Error(
                        `The encoder outputs have a different batch size (${last_hidden_state.dims[0]}) than the decoder inputs (${decoder_input_ids_batch_size}).`
                    )
                }
                last_hidden_state = cat(Array.from({ length: decoder_input_ids_batch_size }, () => last_hidden_state), 0);
            }
        }
        model_inputs['encoder_outputs'] = last_hidden_state;

        return model_inputs;
    }