function decoder_prepare_inputs_for_generation()

in src/models.js [890:936]


/**
 * Prepares the inputs of a decoder for the next generation step.
 *
 * `model_inputs` is mutated in place (and also returned):
 *  1. If `model_inputs.attention_mask` is missing, one is created with the `ones` helper, shaped
 *     `[dims[0], past_length + dims[1]]`, where `dims` comes from the first available of
 *     `input_ids`, `inputs_embeds` or `position_ids`.
 *  2. If `model_inputs.past_key_values` is set and `model_inputs.input_ids` holds more than
 *     `past_length` entries along dimension 1, the first `past_length` entries are sliced off
 *     so that only unprocessed tokens remain.
 *
 * @param {any} self Not used by this function.
 * @param {any} input_ids Not used by this function; `model_inputs.input_ids` is read instead.
 * @param {Object} model_inputs The inputs to prepare. Values are expected to expose `dims`
 * (and `input_ids` a `slice` method).
 * @param {any} generation_config Not used by this function.
 * @returns {Object} The same `model_inputs` object that was passed in.
 * @throws {Error} If `attention_mask` is missing and no input is available to infer its shape from.
 */
function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    // Length already covered by the cache, read from the second-to-last dimension of the first
    // cache entry (presumably the sequence axis - confirm against the cache layout).
    const past_length = model_inputs.past_key_values
        ? Object.values(model_inputs.past_key_values)[0].dims.at(-2)
        : 0;

    if (!model_inputs.attention_mask) {
        // If the attention mask is not provided, we attempt to infer it from the provided inputs.
        const shape_source = ['input_ids', 'inputs_embeds', 'position_ids']
            .map((key) => model_inputs[key])
            .find(Boolean);
        const dims = shape_source?.dims;
        if (!dims) {
            throw new Error("attention_mask is not provided, and unable to infer its shape from model inputs.");
        }
        model_inputs.attention_mask = ones([dims[0], past_length + dims[1]]);
    }

    // `attention_mask` is guaranteed to be set at this point (provided by the caller or created above).
    const { past_key_values, input_ids: current_ids, attention_mask } = model_inputs;

    // Without a cache there is nothing to discard. Without `input_ids` (e.g. only `inputs_embeds`
    // were supplied) there is nothing to trim either, so we must not dereference it.
    if (past_key_values && current_ids) {
        const input_length = current_ids.dims[1];

        // Keep only the unprocessed tokens:
        // 1 - If the attention_mask is longer than input_ids, some of the inputs were passed exclusively
        //     as part of the cache (e.g. when passing inputs_embeds as input). No trimming is needed,
        //     since only the generated tokens are passed to the next forward pass.
        const mask_exceeds_inputs = attention_mask.dims[1] > input_length;

        // 2 - Otherwise, if past_length is smaller than the length of input_ids, then input_ids holds all
        //     input tokens and the ones already covered by the cache can be discarded.
        //     NOTE: Required for phi models.
        //     See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479 for more information.
        // 3 - Otherwise (past_length >= input length), input_ids is assumed to hold only unprocessed tokens.
        if (!mask_exceeds_inputs && past_length < input_length) {
            model_inputs.input_ids = current_ids.slice(null, [past_length, null]);
        }
    }

    return model_inputs;
}