_call()

in src/generation/logits_process.js [271:314]


    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);

            // suppress <|notimestamps|> which is handled by without_timestamps
            batch_logits_data[this.no_timestamps_token_id] = -Infinity;

            if (input_ids[i].length === this.begin_index - 1) {
                batch_logits_data.fill(-Infinity);
                batch_logits_data[this.timestamp_begin] = 0;
                continue;
            }

            // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
            const seq = input_ids[i].slice(this.begin_index);
            const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
            const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

            if (last_was_timestamp) {
                if (penultimate_was_timestamp) { // has to be non-timestamp
                    batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
                } else { // cannot be normal text tokens
                    batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
                }
            }

            // apply the `max_initial_timestamp` option
            if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) {
                const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
                batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
            }

            // if sum of probability over timestamps is above any other token, sample timestamp
            const logprobs = log_softmax(batch_logits_data);
            const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
            const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];

            if (timestamp_logprob > max_text_token_logprob) {
                batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
            }
        }

        return logits;
    }