in src/models.js [3486:3585]
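/**
 * Compute token-level timestamps by running dynamic time warping (DTW) over the
 * cross-attention weights of a set of "alignment heads", as in OpenAI's Whisper.
 * @param {Object} generate_outputs Outputs of `generate`, which must contain
 * `cross_attentions` and `sequences`.
 * @param {number[][]} alignment_heads `[layer, head]` pairs identifying the
 * cross-attention heads to use for alignment.
 * @param {number} [num_frames] Number of encoder frames covered by the actual audio,
 * used to crop the attention weights before alignment.
 * @param {number} [time_precision] Seconds per encoder output frame (0.02 for Whisper).
 * @returns {Tensor} Tensor of shape `[batch_size, sequence_length]` containing a
 * timestamp, in seconds, for each generated token.
 * @throws {Error} If `generate_outputs` does not include cross attentions.
 */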
_extract_token_timestamps(generate_outputs, alignment_heads, num_frames = null, time_precision = 0.02) {
if (!generate_outputs.cross_attentions) {
throw new Error(
"Model outputs must contain cross attentions to extract timestamps. " +
"This is most likely because the model was not exported with `output_attentions=True`."
);
}
if (num_frames == null) {
console.warn(
"`num_frames` has not been set, meaning the entire audio will be analyzed. " +
"This may lead to inaccurate token-level timestamps for short audios (< 30 seconds)."
);
}
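// When available, `num_frames` is the number of encoder frames covered by the
// actual (non-padded) audio; it is used below to crop the attention weights.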
// @ts-expect-error TS2339
let median_filter_width = this.config.median_filter_width;
if (median_filter_width === undefined) {
console.warn("Model config has no `median_filter_width`, using default value of 7.")
median_filter_width = 7;
}
// TODO: Improve batch processing
const batch = generate_outputs.cross_attentions;
// Create a list with `decoder_layers` elements, each a tensor of shape
// (batch size, attention_heads, output length, input length).
// @ts-expect-error TS2339
const cross_attentions = Array.from({ length: this.config.decoder_layers },
// Concatenate the cross attentions for each layer across sequence length dimension.
(_, i) => cat(batch.map(x => x[i]), 2)
);
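// Keep only the [layer, head] pairs listed in `alignment_heads`: the cross-attention
// heads whose weights are known to track the audio's position in time.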
const weights = stack(alignment_heads.map(([l, h]) => {
if (l >= cross_attentions.length) {
throw new Error(`Layer index ${l} is out of bounds for cross attentions (length ${cross_attentions.length}).`)
}
return num_frames
? cross_attentions[l].slice(null, h, null, [0, num_frames])
: cross_attentions[l].slice(null, h);
})).transpose(1, 0, 2, 3);
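// `weights` now has shape (batch_size, num_alignment_heads, output_length, num_frames).
// The statistics below are computed over the token dimension (dim -2, correction 0,
// keepdim), giving a per-frame mean and standard deviation for standardization.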
const [std, calculatedMean] = std_mean(weights, -2, 0, true);
// Normalize and smoothen the weights.
const smoothedWeights = weights.clone(); // [1, 8, seqLength, 1500]
for (let a = 0; a < smoothedWeights.dims[0]; ++a) {
const aTensor = smoothedWeights[a]; // [8, seqLength, 1500]
for (let b = 0; b < aTensor.dims[0]; ++b) {
const bTensor = aTensor[b]; // [seqLength, 1500]
const stdTensorData = std[a][b][0].data; // [1500]
const meanTensorData = calculatedMean[a][b][0].data; // [1500]
for (let c = 0; c < bTensor.dims[0]; ++c) {
const cTensorData = bTensor[c].data; // [1500]
for (let d = 0; d < cTensorData.length; ++d) {
cTensorData[d] = (cTensorData[d] - meanTensorData[d]) / stdTensorData[d];
}
// Apply a median filter along the time axis to smooth out spurious attention spikes.
cTensorData.set(medianFilter(cTensorData, median_filter_width));
}
}
}
// Average the different cross-attention heads.
// NOTE: only a single batch element is supported here (see the TODO above).
const batchedMatrices = [mean(smoothedWeights, 1)];
const timestampsShape = generate_outputs.sequences.dims;
const timestamps = new Tensor(
'float32',
new Float32Array(timestampsShape[0] * timestampsShape[1]),
timestampsShape
);
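// Zero-initialized, so any position that never receives a jump time (notably the
// first token of each sequence) defaults to a timestamp of 0.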
// Perform dynamic time warping on each element of the batch.
for (let batch_idx = 0; batch_idx < timestampsShape[0]; ++batch_idx) {
// NOTE: Since only one batch element is processed at a time, we can squeeze the
// batch dimension to obtain the same shape as the Python implementation.
const matrix = batchedMatrices[batch_idx].neg().squeeze_(0);
const [text_indices, time_indices] = dynamic_time_warping(matrix.tolist());
const diffs = Array.from({ length: text_indices.length - 1 }, (_, i) => text_indices[i + 1] - text_indices[i]);
const jumps = mergeArrays([1], diffs).map(x => !!x); // convert to booleans
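// Example: text_indices [0, 0, 1, 2, 2, 3] gives diffs [0, 1, 1, 0, 1] and
// jumps [true, false, true, true, false, true]; each `true` marks the first
// time step at which the token index advances, i.e. where a new token begins.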
const jump_times = [];
for (let i = 0; i < jumps.length; ++i) {
if (jumps[i]) {
// NOTE: No point in rounding here, since we set to Float32Array later
jump_times.push(time_indices[i] * time_precision);
}
}
// Start writing at index 1 so the first token of each sequence keeps a timestamp of 0.
timestamps[batch_idx].data.set(jump_times, 1);
}
return timestamps;
}
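// Usage sketch (illustrative only; the package name and model id are assumptions
// that depend on the transformers.js version in use): the token-level timestamps
// produced by this method surface through the ASR pipeline when word-level
// timestamps are requested:
//
//   import { pipeline } from '@huggingface/transformers';
//   // The model must have been exported with `output_attentions=True`.
//   const transcriber = await pipeline('automatic-speech-recognition', 'onnx-community/whisper-tiny.en_timestamped');
//   const output = await transcriber(audioUrl, { return_timestamps: 'word' });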