in candle-transformers/src/models/whisper/audio.rs [90:168]
/// Computes the log-mel rows for the frames assigned to worker `ith`.
///
/// Frames are visited in a strided fashion: `ith`, `ith + n_threads`,
/// `ith + 2 * n_threads`, ... up to (but excluding)
/// `min(samples.len() / fft_step + 1, n_len)`.
///
/// # Arguments
/// * `ith` - index of the first frame handled by this call.
/// * `hann` - window coefficients, multiplied element-wise with each frame.
/// * `samples` - input signal; frame `i` starts at `i * fft_step`.
/// * `filters` - filterbank read as `n_mel` consecutive rows of `n_fft`
///   weights, where `n_fft` is `1 + fft_size / 2` (or `1 + fft_size / 4`
///   when `speed_up` is set).
/// * `fft_size` - number of samples fed to `fft` per frame.
/// * `fft_step` - hop between consecutive frames, in samples.
/// * `speed_up` - when set, adjacent power bins are averaged pairwise before
///   the filterbank is applied.
/// * `n_len` - number of frame slots per mel row in the output.
/// * `n_mel` - number of filterbank rows.
/// * `n_threads` - stride between the frames handled by this call.
///
/// # Returns
/// A vector of `n_len * n_mel` values where the entry for filter `j` and
/// frame `i` lives at `j * n_len + i`. Slots for frames not visited by this
/// call are left at zero.
///
/// # Panics
/// Panics if `fft_step` or `n_threads` is zero, or if `hann`, `filters` or
/// the buffer returned by `fft` is too short for the indices used here.
fn log_mel_spectrogram_w<T: Float>(
    ith: usize,
    hann: &[T],
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    speed_up: bool,
    n_len: usize,
    n_mel: usize,
    n_threads: usize,
) -> Vec<T> {
    // Number of frequency bins fed to the filterbank.
    let n_fft = 1 + fft_size / if speed_up { 4 } else { 2 };

    let zero = T::zero();
    let half = T::from(0.5).unwrap();
    let mut fft_in = vec![zero; fft_size];
    let mut mel = vec![zero; n_len * n_mel];
    let n_samples = samples.len();
    let end = (n_samples / fft_step + 1).min(n_len);

    for frame in (ith..end).step_by(n_threads) {
        // `frame < end` guarantees `offset <= n_samples`, so the subtraction
        // below cannot underflow.
        let offset = frame * fft_step;
        let n_valid = (n_samples - offset).min(fft_size);

        // Window the samples that are available for this frame ...
        let frame_samples = &samples[offset..offset + n_valid];
        let window = &hann[..n_valid];
        for ((dst, &w), &s) in fft_in[..n_valid].iter_mut().zip(window).zip(frame_samples) {
            *dst = w * s;
        }
        // ... and zero-pad whatever is left of the FFT input buffer.
        fft_in[n_valid..].fill(zero);

        let mut spectrum: Vec<T> = fft(&fft_in);

        // The FFT output is treated as interleaved (re, im) pairs. Squared
        // magnitudes are compacted in place: slot `j` is written only after
        // slots `2 * j` and `2 * j + 1` have been read.
        for j in 0..fft_size {
            let re = spectrum[2 * j];
            let im = spectrum[2 * j + 1];
            spectrum[j] = re * re + im * im;
        }
        // Fold the upper half of the power spectrum onto the lower half.
        for j in 1..fft_size / 2 {
            let mirrored = spectrum[fft_size - j];
            spectrum[j] += mirrored;
        }

        if speed_up {
            // Scaling down in the frequency domain results in a speed up in
            // the time domain: average neighbouring bins, again in place.
            for j in 0..n_fft {
                let pair_sum = spectrum[2 * j] + spectrum[2 * j + 1];
                spectrum[j] = half * pair_sum;
            }
        }

        // Apply each filterbank row to the first `n_fft` power bins.
        let bins = &spectrum[..n_fft];
        for filter_idx in 0..n_mel {
            let row = &filters[filter_idx * n_fft..(filter_idx + 1) * n_fft];
            let mut bin_quads = bins.chunks_exact(4);
            let mut weight_quads = row.chunks_exact(4);

            let mut acc = zero;
            // Accumulate four products at a time; the grouping of the
            // additions is kept so the floating-point result is unchanged.
            for (p, w) in bin_quads.by_ref().zip(weight_quads.by_ref()) {
                acc += p[0] * w[0] + p[1] * w[1] + p[2] * w[2] + p[3] * w[3];
            }
            // Leftover bins when `n_fft` is not a multiple of four.
            for (&p, &w) in bin_quads.remainder().iter().zip(weight_quads.remainder()) {
                acc += p * w;
            }

            // Clamp before the logarithm so that silent frames stay finite.
            let clamped = T::max(acc, T::from(1e-10).unwrap());
            mel[filter_idx * n_len + frame] = clamped.log10();
        }
    }

    mel
}