in src/models.js [1740:1949]
async generate({
inputs = null,
generation_config = null,
logits_processor = null,
stopping_criteria = null,
streamer = null,
// inputs_attention_mask = null,
...kwargs
}) {
this._validate_model_class();
// Update generation config with defaults and kwargs
generation_config = this._prepare_generation_config(generation_config, kwargs);
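// Illustrative usage (not part of the original source): callers usually pass
// overrides as kwargs, e.g. `model.generate({ ...inputs, max_new_tokens: 50, do_sample: true })`,
// and those values take precedence over the model's default `generation_config`.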
// 3. Define model inputs
let { inputs_tensor, model_inputs, model_input_name } = this._prepare_model_inputs({
inputs,
model_kwargs: kwargs,
});
const is_encoder_decoder = this.config.is_encoder_decoder;
// 4. Define other model kwargs
if (!is_encoder_decoder) {
// decoder-only models should use left-padding for generation
} else if (!('encoder_outputs' in model_inputs)) {
// if the model is an encoder-decoder, `encoder_outputs` are created
// and added to `model_inputs`
model_inputs = await this._prepare_encoder_decoder_kwargs_for_generation(
{ inputs_tensor, model_inputs, model_input_name, generation_config }
);
}
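// Note (editorial): for encoder-decoder models (e.g. T5, Whisper), the encoder
// forward pass runs exactly once here; the decoding loop below reuses the cached
// `encoder_outputs` instead of re-encoding the input at every step.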
// 5. Prepare `input_ids` which will be used for auto-regressive generation
// TODO: Update to align with HF transformers' implementation
let input_ids;
if (is_encoder_decoder) {
// Generating from the encoder outputs
({ input_ids, model_inputs } = this._prepare_decoder_input_ids_for_generation({
batch_size: model_inputs[model_input_name].dims.at(0),
model_input_name,
model_kwargs: model_inputs,
decoder_start_token_id: generation_config.decoder_start_token_id,
bos_token_id: generation_config.bos_token_id,
generation_config,
}));
} else {
input_ids = model_inputs[model_input_name];
}
// 6. Prepare `max_length` depending on other stopping criteria.
let input_ids_length = input_ids.dims.at(-1);
if (generation_config.max_new_tokens !== null) {
generation_config.max_length = input_ids_length + generation_config.max_new_tokens;
}
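// Worked example (assumed values): with an input of 8 tokens and
// `max_new_tokens: 24`, this sets `generation_config.max_length = 8 + 24 = 32`.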
// 8. Prepare distribution pre-processing samplers (the logits processor list)
const prepared_logits_processor = this._get_logits_processor(
generation_config,
input_ids_length,
logits_processor,
);
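// Note (editorial): the returned callable applies each configured processor in
// sequence (e.g. a repetition penalty or a minimum-length constraint, depending
// on `generation_config`), producing the adjusted scores used for sampling below.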
// 9. Prepare stopping criteria
const prepared_stopping_criteria = this._get_stopping_criteria(
generation_config, stopping_criteria
);
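// Note (editorial): the prepared criteria typically include a max-length check
// and an EOS-token check; each call returns one boolean per sequence, and the
// loop below exits once every sequence reports done.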
const numInputs = model_inputs[model_input_name].dims.at(0);
// TODO: For efficiency, remove completed rows from `model_inputs` once their
// generation finishes (e.g., track a `done` boolean per input and a map from
// row index back to the original batch index).
const sampler = LogitsSampler.getSampler(generation_config);
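// Note (editorial): `LogitsSampler.getSampler` is expected to pick the sampling
// strategy from the config, e.g. greedy (argmax) when `do_sample` is false and
// multinomial sampling (honouring `temperature`, `top_k`, etc.) when it is true.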
// TODO: allow the scores array to grow beyond `numInputs` (needed once beam search can spawn new beams)
const scores = new Array(numInputs).fill(0);
/** @type {bigint[][]} */
const all_input_ids = input_ids.tolist();
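// Shape sketch (hypothetical values): for a batch of two prompts,
// `all_input_ids` might look like `[[1n, 50n, 23n], [1n, 7n, 99n]]`,
// one row of bigint token ids per sequence.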
if (streamer) {
streamer.put(all_input_ids);
}
// NOTE: For now, we don't support spawning new beams.
// TODO: when we do, simply copy the past key values and accumulate them into a single large tensor.
////////////////////////////////////////////////////
// Generic search which handles 4 generation modes:
// - GenerationMode.GREEDY_SEARCH
// - GenerationMode.SAMPLE
// - GenerationMode.BEAM_SEARCH
// - GenerationMode.BEAM_SAMPLE
////////////////////////////////////////////////////
let outputs;
let attentions = {};
while (true) {
// prepare model inputs
model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config);
outputs = await this.forward(model_inputs);
if (generation_config.output_attentions && generation_config.return_dict_in_generate) {
// Get attentions if they are present
const token_attentions = this.getAttentions(outputs);
for (const key in token_attentions) {
if (!(key in attentions)) {
attentions[key] = [];
}
attentions[key].push(token_attentions[key]);
}
}
// Logits are of the form [batch_size, out_seq_length, vocab_size]
// In most cases, this will be [batch_size, 1, vocab_size]
// So, we select the last token's logits:
// (equivalent to `logits = outputs.logits[:, -1, :]`)
const logits = outputs.logits.slice(null, -1, null);
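// Shape sketch (assumed dims): if `outputs.logits` is [2, 5, 32000], this slice
// keeps only the last position, giving a [2, 1, 32000] tensor of next-token logits.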
const next_tokens_scores = prepared_logits_processor(all_input_ids, logits);
/** @type {[bigint][]} */
const generated_input_ids = [];
// NOTE: a `new_kv_cache` buffer will only be needed for beam search, when concatenating new key/value states
// Loop over each batch
for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) {
const batchLogits = next_tokens_scores[batch_idx];
const sampledTokens = await sampler(batchLogits);
for (const [newTokenId, logProb] of sampledTokens) {
const bigTokenId = BigInt(newTokenId);
// TODO: If branching, use the previous beam as a starting point
// update generated ids, model inputs, and length for the next step
scores[batch_idx] += logProb;
all_input_ids[batch_idx].push(bigTokenId);
generated_input_ids.push([bigTokenId]);
// TODO: Support beam search
break;
}
}
if (streamer) {
streamer.put(generated_input_ids);
}
const stop = prepared_stopping_criteria(all_input_ids);
if (stop.every(x => x)) {
break;
}
model_inputs = this._update_model_kwargs_for_generation({
generated_input_ids, outputs, model_inputs, is_encoder_decoder,
});
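// Note (editorial): this carries `outputs.past_key_values` forward and feeds only
// the freshly sampled token ids back in, so each subsequent forward pass processes
// a single new token per sequence (standard KV-cache decoding).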
}
if (streamer) {
streamer.end();
}
// Retrieve and dispose all final past key values (including encoder attentions)
const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true);
// TODO: ensure all_input_ids is padded correctly...
const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]);
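// Shape sketch: for two sequences of 12 tokens each (prompt + generated),
// `sequences` is an int64 tensor with dims [2, 12].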
if (generation_config.return_dict_in_generate) {
return {
sequences,
past_key_values,
...attentions,
// TODO:
// scores,
// logits,
}
} else {
// Dispose all remaining tensors
for (const tensor of Object.values(outputs)) {
if (tensor.location === 'gpu-buffer') {
tensor.dispose();
}
}
return sequences;
}
}
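// Example usage (sketch only; hypothetical variable names, assuming a tokenizer
// that returns `input_ids` and `attention_mask` tensors):
//   const inputs = tokenizer('Hello, world!');
//   const output_ids = await model.generate({ ...inputs, max_new_tokens: 20 });
//   console.log(tokenizer.batch_decode(output_ids, { skip_special_tokens: true }));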