async generate()

in src/models.js [1740:1949]


    async generate({
        inputs = null,
        generation_config = null,
        logits_processor = null,
        stopping_criteria = null,
        streamer = null,

        // inputs_attention_mask = null,
        ...kwargs
    }) {
        this._validate_model_class();

        // Update generation config with defaults and kwargs
        generation_config = this._prepare_generation_config(generation_config, kwargs);

        // 3. Define model inputs
        let { inputs_tensor, model_inputs, model_input_name } = this._prepare_model_inputs({
            inputs,
            model_kwargs: kwargs,
        });

        const is_encoder_decoder = this.config.is_encoder_decoder;

        // 4. Define other model kwargs
        if (!is_encoder_decoder) {
            // decoder-only models should use left-padding for generation
        } else if (!('encoder_outputs' in model_inputs)) {
            // If the model is an encoder-decoder, run the encoder once and add
            // its `encoder_outputs` to `model_inputs`.
            model_inputs = await this._prepare_encoder_decoder_kwargs_for_generation(
                { inputs_tensor, model_inputs, model_input_name, generation_config }
            );
        }

        // 5. Prepare `input_ids` which will be used for auto-regressive generation
        // TODO: Update to align with HF transformers' implementation
        let input_ids;
        if (is_encoder_decoder) {
            // Generating from the encoder outputs
            ({ input_ids, model_inputs } = this._prepare_decoder_input_ids_for_generation({
                batch_size: model_inputs[model_input_name].dims.at(0),
                model_input_name,
                model_kwargs: model_inputs,
                decoder_start_token_id: generation_config.decoder_start_token_id,
                bos_token_id: generation_config.bos_token_id,
                generation_config,
            }));
        } else {
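            // Decoder-only: generate directly from the provided input ids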
            input_ids = model_inputs[model_input_name];
        }

        // 6. Prepare `max_length` depending on other stopping criteria.
        let input_ids_length = input_ids.dims.at(-1);

        if (generation_config.max_new_tokens !== null) {
            generation_config.max_length = input_ids_length + generation_config.max_new_tokens;
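            // e.g. a 10-token prompt with `max_new_tokens: 20` gives `max_length = 30`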
        }

        // 8. Prepare the logits processors (distribution pre-processing)
        const prepared_logits_processor = this._get_logits_processor(
            generation_config,
            input_ids_length,
            logits_processor,
        )

        // 9. Prepare the stopping criteria
        const prepared_stopping_criteria = this._get_stopping_criteria(
            generation_config, stopping_criteria
        )

        const numInputs = model_inputs[model_input_name].dims.at(0);

        // TODO (beam search):
        // - Track which sequences are finished, e.g. `const done = new Array(numInputs).fill(false);`
        // - For efficiency, remove completed rows from `model_inputs` once a beam
        //   finishes, keeping a row-to-batch index map, e.g. `const rowIndexToBatchIndex = new Map();`

        const sampler = LogitsSampler.getSampler(generation_config);
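        // `getSampler` chooses the decoding strategy (e.g. greedy argmax vs.
        // multinomial sampling) based on `generation_config`.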

        // TODO: allow the number of tracked sequences to grow beyond `numInputs` (beam search)
        const scores = new Array(numInputs).fill(0);
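        // Keep token ids as plain (mutable) arrays so we can append the sampled
        // token for each sequence in place at every step.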
        /** @type {bigint[][]} */
        const all_input_ids = input_ids.tolist();
        if (streamer) {
            streamer.put(all_input_ids);
        }

        // NOTE: For now, we don't support spawning new beams
        // TODO: when we do, we simply copy past key values and accumulate into single large tensor

        ////////////////////////////////////////////////////
        // Generic search which handles 4 generation modes:
        // - GenerationMode.GREEDY_SEARCH
        // - GenerationMode.SAMPLE
        // - GenerationMode.BEAM_SEARCH
        // - GenerationMode.BEAM_SAMPLE
        ////////////////////////////////////////////////////
        let outputs;
        let attentions = {};
        while (true) {
            // prepare model inputs
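            // (with a KV cache present, typically only the newly generated tokens
            // need to be passed through the model)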
            model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config);
            outputs = await this.forward(model_inputs);

            if (generation_config.output_attentions && generation_config.return_dict_in_generate) {
                // Get attentions if they are present
                const token_attentions = this.getAttentions(outputs);
                for (const key in token_attentions) {
                    if (!(key in attentions)) {
                        attentions[key] = [];
                    }
                    attentions[key].push(token_attentions[key]);
                }
            }

            // Logits are of the form [batch_size, out_seq_length, vocab_size]
            // In most cases, this will be [batch_size, 1, vocab_size]
            // So, we select the last token's logits:
            // (equivalent to `logits = outputs.logits[:, -1, :]`)
            const logits = outputs.logits.slice(null, -1, null);

            const next_tokens_scores = prepared_logits_processor(all_input_ids, logits);

            /** @type {[bigint][]} */
            const generated_input_ids = [];
            // const new_kv_cache = []; // NOTE: Only used for beam search when concatenating new kv
            // Loop over each batch
            for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) {
                const logs = next_tokens_scores[batch_idx];

                const sampledTokens = await sampler(logs);
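                // Each entry is a [tokenId, logProb] pair; greedy/multinomial
                // sampling yields a single candidate per sequence.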
                for (const [newTokenId, logProb] of sampledTokens) {
                    const bigint = BigInt(newTokenId);
                    // TODO: If branching, use previous beam as a starting point
                    // update generated ids, model inputs, and length for next step
                    scores[batch_idx] += logProb;
                    all_input_ids[batch_idx].push(bigint);
                    generated_input_ids.push([bigint]);

                    // TODO: Support beam search
                    break;
                }
            }
            if (streamer) {
                streamer.put(generated_input_ids);
            }

            const stop = prepared_stopping_criteria(all_input_ids);
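            // `stop` holds one boolean per sequence; exit only once every
            // sequence satisfies a stopping criterion.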
            if (stop.every(x => x)) {
                break;
            }

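            // Feed the sampled tokens back into `model_inputs` (reusing the KV
            // cache from `outputs`) for the next decoding step.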
            model_inputs = this._update_model_kwargs_for_generation({
                generated_input_ids, outputs, model_inputs, is_encoder_decoder,
            });
        }

        if (streamer) {
            streamer.end();
        }

        // Retrieve and dispose all final past key values (including encoder attentions)
        const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true);

        // TODO: ensure all_input_ids is padded correctly...
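        // Stack all sequences into a single [batch_size, sequence_length] tensor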
        const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]);

        if (generation_config.return_dict_in_generate) {
            return {
                sequences,
                past_key_values,
                ...attentions,
                // TODO:
                // scores,
                // logits,
            }
        } else {
            // Dispose all remaining tensors
            for (const tensor of Object.values(outputs)) {
                if (tensor.location === 'gpu-buffer') {
                    tensor.dispose();
                }
            }
            return sequences;
        }
    }
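
A minimal usage sketch (model id and prompt are illustrative; it assumes the
transformers.js `AutoTokenizer` / `AutoModelForCausalLM` loaders, with extra
kwargs such as `max_new_tokens` forwarded into the generation config as above):

    import { AutoTokenizer, AutoModelForCausalLM } from '@huggingface/transformers';

    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2');
    const model = await AutoModelForCausalLM.from_pretrained('Xenova/gpt2');

    // Tokenize the prompt into { input_ids, attention_mask } tensors
    const inputs = tokenizer('Once upon a time');

    // Extra kwargs (e.g. `max_new_tokens`) are merged into the generation config
    const sequences = await model.generate({ ...inputs, max_new_tokens: 20 });

    // Decode the generated token ids back to text
    console.log(tokenizer.batch_decode(sequences, { skip_special_tokens: true }));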