in src/generation/logits_process.js [271:314]
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
// suppress <|notimestamps|> which is handled by without_timestamps
batch_logits_data[this.no_timestamps_token_id] = -Infinity;
if (input_ids[i].length === this.begin_index - 1) {
batch_logits_data.fill(-Infinity);
batch_logits_data[this.timestamp_begin] = 0;
continue;
}
// timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
const seq = input_ids[i].slice(this.begin_index);
const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;
if (last_was_timestamp) {
if (penultimate_was_timestamp) { // has to be non-timestamp
batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
} else { // cannot be normal text tokens
batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
}
}
// apply the `max_initial_timestamp` option
if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
}
// if sum of probability over timestamps is above any other token, sample timestamp
const logprobs = log_softmax(batch_logits_data);
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];
if (timestamp_logprob > max_text_token_logprob) {
batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
}
}
return logits;
}