crates/ratchet-models/src/whisper/spectrogram.rs (29 lines of code) (raw):

//Adapted from: https://github.com/tanmayb123/OpenAI-Whisper-CoreML use ndarray::{Array1, Array2}; use ndarray_stats::QuantileExt; use num::complex::Complex; use ratchet::Tensor; use realfft::{RealFftPlanner, RealToComplex}; use std::f32::consts::PI; use std::sync::Arc; pub static SAMPLE_RATE: usize = 16000; pub static N_FFT: usize = 400; pub static HOP_LENGTH: usize = 160; pub static CHUNK_LENGTH: usize = 30; pub static N_AUDIO_CTX: usize = 1500; //same for all pub static N_SAMPLES: usize = SAMPLE_RATE * CHUNK_LENGTH; // 480000 pub static N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 pub static FFT_PAD: usize = N_FFT / 2; #[derive(Debug, thiserror::Error)] pub enum AudioError { #[error("Audio must be 30 seconds long (with stft padding): {0} != {1}")] InvalidLength(usize, usize), #[error("Invalid audio provided: {0}")] InvalidAudio(#[from] anyhow::Error), } pub struct SpectrogramGenerator { fft_plan: Arc<dyn RealToComplex<f32>>, hann_window: Array1<f32>, mels: Array2<f32>, } impl std::fmt::Debug for SpectrogramGenerator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SpectrogramGenerator").finish() } } impl SpectrogramGenerator { pub fn new(mels: Vec<f32>) -> Self { let mut planner = RealFftPlanner::new(); let n_mels = mels.len() / (N_FFT / 2 + 1); Self { fft_plan: planner.plan_fft_forward(N_FFT), hann_window: Self::hann_window(), mels: Array2::from_shape_vec((n_mels, N_FFT / 2 + 1), mels).unwrap(), } } fn hann_window() -> Array1<f32> { let window = (0..N_FFT) .map(|i| (i as f32 * 2.0 * PI) / N_FFT as f32) .map(|i| (1.0 - i.cos()) / 2.0) .collect::<Vec<_>>(); Array1::from(window) } fn fft(&self, audio: &[f32]) -> Vec<Complex<f32>> { let mut input = Array1::from_vec(audio.to_vec()); input *= &self.hann_window; let mut spectrum = self.fft_plan.make_output_vec(); self.fft_plan .process(input.as_slice_mut().unwrap(), &mut spectrum) .unwrap(); spectrum } fn mel_spectrogram(&self, audio: &[f32]) -> Tensor { let n_frames = (audio.len() - N_FFT) / HOP_LENGTH; let right_padding = N_SAMPLES + FFT_PAD; //padding is all 0s, so we can ignore it let mut spectrogram = Array2::<f32>::zeros((201, n_frames)); for i in (0..audio.len() - right_padding).step_by(HOP_LENGTH) { if i / HOP_LENGTH >= n_frames { break; } let fft = self.fft(&audio[i..i + N_FFT]); let spectrogram_col = fft.iter().map(|c| c.norm_sqr()).collect::<Array1<f32>>(); spectrogram .column_mut(i / HOP_LENGTH) .assign(&spectrogram_col); } let mut mel_spec = self.mels.dot(&spectrogram); mel_spec.mapv_inplace(|x| x.max(1e-10).log10()); let max = *mel_spec.max().unwrap(); mel_spec.mapv_inplace(|x| (x.max(max - 8.0) + 4.0) / 4.0); let expanded = mel_spec.insert_axis(ndarray::Axis(0)); Tensor::from(expanded.into_dyn()) } pub fn generate(&self, audio: Vec<f32>) -> Result<Tensor, AudioError> { if audio.is_empty() { return Err(AudioError::InvalidAudio(anyhow::anyhow!( "Audio must be non-empty" ))); } let padded = Self::pad_audio(audio, N_SAMPLES); Ok(self.mel_spectrogram(&padded)) } //The padding done by OAI is as follows: //1. First explicitly pad with (CHUNK_LENGTH * SAMPLE_RATE) (480,000) zeros //2. Then perform a reflection padding of FFT_PAD (200) samples on each side // This must be done with care, because we have already performed the explicit padding // the pre-padding will contain non-zero values, but the post-padding must be zero pub fn pad_audio(audio: Vec<f32>, padding: usize) -> Vec<f32> { let padded_len = FFT_PAD + audio.len() + padding + FFT_PAD; let mut padded_samples = vec![0.0; padded_len]; let mut reflect_padding = vec![0.0; FFT_PAD]; for i in 0..FFT_PAD { reflect_padding[i] = audio[FFT_PAD - i]; } padded_samples[0..FFT_PAD].copy_from_slice(&reflect_padding); padded_samples[FFT_PAD..(FFT_PAD + audio.len())].copy_from_slice(&audio); padded_samples } } #[cfg(all(test, feature = "pyo3", not(target_arch = "wasm32")))] mod tests { use super::SpectrogramGenerator; use hf_hub::api::sync::Api; use ratchet::test_util::run_py_prg; use ratchet::DType; use std::path::PathBuf; const MAX_DIFF: f32 = 5e-5; pub fn load_npy(path: PathBuf) -> Vec<f32> { let bytes = std::fs::read(path).unwrap(); npyz::NpyFile::new(&bytes[..]).unwrap().into_vec().unwrap() } fn load_sample(path: PathBuf) -> Vec<f32> { let mut reader = hound::WavReader::open(path).unwrap(); reader .samples::<i16>() .map(|x| x.unwrap() as f32 / 32768.0) .collect::<Vec<_>>() } #[test] fn spectrogram_matches() { let api = Api::new().unwrap(); let repo = api.dataset("FL33TW00D-HF/ratchet-util".to_string()); let gb0 = repo.get("erwin_jp.wav").unwrap(); let mels = repo.get("mel_filters_128.npy").unwrap(); let prg = format!( r#" import whisper import numpy as np def ground_truth(): audio = whisper.load_audio("{}") return whisper.log_mel_spectrogram(audio, n_mels=128, padding=480000).numpy()[np.newaxis] "#, gb0.to_str().unwrap() ); let ground_truth = run_py_prg(prg.to_string(), &[], &[], DType::F32).unwrap(); let generator = SpectrogramGenerator::new(load_npy(mels)); let result = generator.generate(load_sample(gb0)).unwrap(); ground_truth.all_close(&result, MAX_DIFF, MAX_DIFF).unwrap(); } }