_extract_token_timestamps()

in src/models.js [3486:3585]


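    /**
     * Compute token-level timestamps by running dynamic time warping (DTW) over the
     * encoder-decoder cross-attentions, mapping each generated token to a position in
     * the input audio. When `num_frames` is provided, the attention weights are
     * truncated to that many encoder frames before alignment.
     * @param {Object} generate_outputs Outputs of generation; must include `cross_attentions` and `sequences`.
     * @param {number[][]} alignment_heads The `[layer, head]` pairs used for alignment.
     * @param {number} [num_frames=null] Number of valid frames in the input audio.
     * @param {number} [time_precision=0.02] Duration of a single encoder frame, in seconds.
     * @returns {Tensor} A tensor of shape `[batch_size, sequence_length]` with per-token timestamps in seconds.
     */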
    _extract_token_timestamps(generate_outputs, alignment_heads, num_frames = null, time_precision = 0.02) {
        if (!generate_outputs.cross_attentions) {
            throw new Error(
                "Model outputs must contain cross attentions to extract timestamps. " +
                "This is most likely because the model was not exported with `output_attentions=True`."
            )
        }
        if (num_frames == null) {
            console.warn(
                "`num_frames` has not been set, meaning the entire audio will be analyzed. " +
                "This may lead to inaccurate token-level timestamps for short audios (< 30 seconds)."
            );
        }

        // @ts-expect-error TS2339
        let median_filter_width = this.config.median_filter_width;
        if (median_filter_width === undefined) {
            console.warn("Model config has no `median_filter_width`, using default value of 7.")
            median_filter_width = 7;
        }

        // TODO: Improve batch processing
        const batch = generate_outputs.cross_attentions;
        // Create a list with `decoder_layers` elements, each a tensor of shape
        // (batch size, attention_heads, output length, input length).
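        // During generation, each decoding step emits cross-attentions with an output length of 1,
        // so concatenating along dim 2 below rebuilds the full (output length x input length) matrix per layer.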
        // @ts-expect-error TS2339
        const cross_attentions = Array.from({ length: this.config.decoder_layers },
            // Concatenate the cross attentions for each layer across the sequence length dimension.
            (_, i) => cat(batch.map(x => x[i]), 2)
        );

        const weights = stack(alignment_heads.map(([l, h]) => {
            if (l >= cross_attentions.length) {
                throw new Error(`Layer index ${l} is out of bounds for cross attentions (length ${cross_attentions.length}).`)
            }
            return num_frames
                ? cross_attentions[l].slice(null, h, null, [0, num_frames])
                : cross_attentions[l].slice(null, h);
        })).transpose(1, 0, 2, 3);

        const [std, calculatedMean] = std_mean(weights, -2, 0, true);
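        // The mean and standard deviation are taken over the token dimension (dim -2, keepdim),
        // giving per-frame statistics for each head; these are used to normalize each column below.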

        // Normalize and smooth the weights.
        const smoothedWeights = weights.clone(); // [1, 8, seqLength, 1500]

        for (let a = 0; a < smoothedWeights.dims[0]; ++a) {
            const aTensor = smoothedWeights[a]; // [8, seqLength, 1500]

            for (let b = 0; b < aTensor.dims[0]; ++b) {
                const bTensor = aTensor[b]; // [seqLength, 1500]

                const stdTensorData = std[a][b][0].data; // [1500]
                const meanTensorData = calculatedMean[a][b][0].data; // [1500]

                for (let c = 0; c < bTensor.dims[0]; ++c) {

                    let cTensorData = bTensor[c].data; // [1500]
                    for (let d = 0; d < cTensorData.length; ++d) {
                        cTensorData[d] = (cTensorData[d] - meanTensorData[d]) / stdTensorData[d]
                    }

                    // Apply median filter.
                    cTensorData.set(medianFilter(cTensorData, median_filter_width))
                }
            }
        }

        // Average the different cross-attention heads.
        const batchedMatrices = [mean(smoothedWeights, 1)];

        const timestampsShape = generate_outputs.sequences.dims;

        const timestamps = new Tensor(
            'float32',
            new Float32Array(timestampsShape[0] * timestampsShape[1]),
            timestampsShape
        );

        // Perform dynamic time warping on each element of the batch.
        for (let batch_idx = 0; batch_idx < timestampsShape[0]; ++batch_idx) {
            // NOTE: Since we only process one batch at a time, we can squeeze to get the same dimensions
            // as the Python implementation.
            const matrix = batchedMatrices[batch_idx].neg().squeeze_(0);
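            // DTW finds the cheapest monotonic path through the negated attention matrix,
            // aligning output token indices (rows) to input frame indices (columns).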
            const [text_indices, time_indices] = dynamic_time_warping(matrix.tolist());

            const diffs = Array.from({ length: text_indices.length - 1 }, (v, i) => text_indices[i + 1] - text_indices[i]);
            const jumps = mergeArrays([1], diffs).map(x => !!x); // convert to boolean
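            // `jumps[i]` is true whenever the alignment path advances to a new token,
            // i.e. the token at `text_indices[i]` begins at frame `time_indices[i]`.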

            const jump_times = [];
            for (let i = 0; i < jumps.length; ++i) {
                if (jumps[i]) {
                    // NOTE: No point in rounding here, since the values are written into a Float32Array below
                    jump_times.push(time_indices[i] * time_precision);
                }
            }
            timestamps[batch_idx].data.set(jump_times, 1)
        }

        return timestamps;
    }
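
For context, this method is not normally called directly; it runs during generation when token-level timestamps are requested. Below is a minimal usage sketch, assuming the `@huggingface/transformers` pipeline API; the model id and audio path are placeholders, and word-level timestamps require a Whisper model exported with `output_attentions=True` so that cross-attentions are available:

    import { pipeline } from '@huggingface/transformers';

    // Requesting word-level timestamps makes the pipeline pass the cross-attentions
    // through `_extract_token_timestamps` during generation.
    const transcriber = await pipeline(
        'automatic-speech-recognition',
        'onnx-community/whisper-tiny.en', // placeholder model id
    );

    const output = await transcriber('audio.wav', {
        return_timestamps: 'word', // token-level timestamps
    });
    // output.chunks ~> [{ text: ' Hello', timestamp: [0.0, 0.42] }, ...]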