onedrop_engine/
fft.rs

1//! FFT-based audio analysis for advanced frequency detection.
2
3use std::f32::consts::PI;
4
/// FFT analyzer for audio frequency analysis.
///
/// Holds the precomputed window and the scratch buffers, so repeated calls to
/// `analyze` reuse the same storage instead of allocating.
#[derive(Debug, Clone)]
pub struct FFTAnalyzer {
    /// FFT size in samples (a power of two when constructed via `new`).
    fft_size: usize,

    /// Periodic Hann window of length `fft_size`, applied to each input frame.
    window: Vec<f32>,

    /// Interleaved complex scratch buffer `[re0, im0, re1, im1, ...]` of length `2 * fft_size`.
    fft_buffer: Vec<f32>,

    /// Magnitude spectrum of the most recently analyzed frame (`fft_size / 2` bins).
    bins: Vec<f32>,

    /// Sample rate of the input audio, in Hz.
    sample_rate: f32,
}
22
/// Error reported by FFT operations.
///
/// Carries a free-form, human-readable description of the failure.
#[derive(Clone, Debug)]
pub struct FFTError {
    /// Human-readable description of what went wrong.
    pub message: String,
}
28
29impl std::fmt::Display for FFTError {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "FFT Error: {}", self.message)
32    }
33}
34
35impl std::error::Error for FFTError {}
36
37impl FFTAnalyzer {
38    /// Create a new FFT analyzer.
39    /// Returns None if fft_size is not a power of 2.
40    pub fn new(fft_size: usize, sample_rate: f32) -> Option<Self> {
41        if !fft_size.is_power_of_two() || fft_size == 0 {
42            log::warn!("FFT size must be a power of 2, got {}", fft_size);
43            return None;
44        }
45
46        // Create Hann window
47        let window: Vec<f32> = (0..fft_size)
48            .map(|i| {
49                let x = i as f32 / fft_size as f32;
50                0.5 * (1.0 - (2.0 * PI * x).cos())
51            })
52            .collect();
53
54        Some(Self {
55            fft_size,
56            window,
57            fft_buffer: vec![0.0; fft_size * 2], // Real + imaginary
58            bins: vec![0.0; fft_size / 2],
59            sample_rate,
60        })
61    }
62
63    /// Create a new FFT analyzer, panicking on invalid input.
64    /// Use `new` for fallible construction.
65    pub fn new_or_default(fft_size: usize, sample_rate: f32) -> Self {
66        Self::new(fft_size, sample_rate).unwrap_or_else(|| {
67            log::warn!("Invalid FFT size {}, using default 256", fft_size);
68            Self::new(256, sample_rate).expect("256 is always valid")
69        })
70    }
71
72    /// Analyze audio samples and return frequency bins.
73    pub fn analyze(&mut self, samples: &[f32]) -> &[f32] {
74        // Ensure we have enough samples
75        let num_samples = samples.len().min(self.fft_size);
76
77        // Apply window and copy to FFT buffer
78        for (i, (sample, window)) in samples
79            .iter()
80            .zip(self.window.iter())
81            .take(num_samples)
82            .enumerate()
83        {
84            self.fft_buffer[i * 2] = sample * window; // Real
85            self.fft_buffer[i * 2 + 1] = 0.0; // Imaginary
86        }
87
88        // Zero-pad if necessary
89        for i in num_samples..self.fft_size {
90            self.fft_buffer[i * 2] = 0.0;
91            self.fft_buffer[i * 2 + 1] = 0.0;
92        }
93
94        // Perform FFT (simple implementation)
95        self.fft_inplace();
96
97        // Calculate magnitude spectrum
98        for i in 0..self.bins.len() {
99            let real = self.fft_buffer[i * 2];
100            let imag = self.fft_buffer[i * 2 + 1];
101            self.bins[i] = (real * real + imag * imag).sqrt() / self.fft_size as f32;
102        }
103
104        &self.bins
105    }
106
107    /// Get bass level (20-250 Hz).
108    pub fn get_bass(&self) -> f32 {
109        self.get_frequency_range(20.0, 250.0)
110    }
111
112    /// Get mid level (250-2000 Hz).
113    pub fn get_mid(&self) -> f32 {
114        self.get_frequency_range(250.0, 2000.0)
115    }
116
117    /// Get treble level (2000-20000 Hz).
118    pub fn get_treble(&self) -> f32 {
119        self.get_frequency_range(2000.0, 20000.0)
120    }
121
122    /// Get energy in a frequency range.
123    fn get_frequency_range(&self, min_freq: f32, max_freq: f32) -> f32 {
124        let bin_width = self.sample_rate / self.fft_size as f32;
125        let min_bin = (min_freq / bin_width) as usize;
126        let max_bin = ((max_freq / bin_width) as usize).min(self.bins.len());
127
128        if min_bin >= max_bin {
129            return 0.0;
130        }
131
132        let sum: f32 = self.bins[min_bin..max_bin].iter().sum();
133        sum / (max_bin - min_bin) as f32
134    }
135
136    /// Simple in-place FFT (Cooley-Tukey algorithm).
137    fn fft_inplace(&mut self) {
138        let n = self.fft_size;
139
140        // Bit-reversal permutation
141        let mut j = 0;
142        for i in 0..n {
143            if i < j {
144                self.fft_buffer.swap(i * 2, j * 2);
145                self.fft_buffer.swap(i * 2 + 1, j * 2 + 1);
146            }
147
148            let mut m = n / 2;
149            while m >= 1 && j >= m {
150                j -= m;
151                m /= 2;
152            }
153            j += m;
154        }
155
156        // FFT computation
157        let mut len = 2;
158        while len <= n {
159            let angle = -2.0 * PI / len as f32;
160            let wlen_real = angle.cos();
161            let wlen_imag = angle.sin();
162
163            let mut i = 0;
164            while i < n {
165                let mut w_real = 1.0;
166                let mut w_imag = 0.0;
167
168                for j in 0..len / 2 {
169                    let u_idx = (i + j) * 2;
170                    let v_idx = (i + j + len / 2) * 2;
171
172                    let u_real = self.fft_buffer[u_idx];
173                    let u_imag = self.fft_buffer[u_idx + 1];
174                    let v_real = self.fft_buffer[v_idx];
175                    let v_imag = self.fft_buffer[v_idx + 1];
176
177                    let t_real = w_real * v_real - w_imag * v_imag;
178                    let t_imag = w_real * v_imag + w_imag * v_real;
179
180                    self.fft_buffer[u_idx] = u_real + t_real;
181                    self.fft_buffer[u_idx + 1] = u_imag + t_imag;
182                    self.fft_buffer[v_idx] = u_real - t_real;
183                    self.fft_buffer[v_idx + 1] = u_imag - t_imag;
184
185                    let w_temp = w_real;
186                    w_real = w_real * wlen_real - w_imag * wlen_imag;
187                    w_imag = w_temp * wlen_imag + w_imag * wlen_real;
188                }
189
190                i += len;
191            }
192
193            len *= 2;
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft_analyzer() {
        let mut analyzer = FFTAnalyzer::new(256, 44100.0).expect("256 is valid FFT size");

        // Generate a simple sine wave at 440 Hz (A4)
        let samples: Vec<f32> = (0..256)
            .map(|i| {
                let t = i as f32 / 44100.0;
                (2.0 * PI * 440.0 * t).sin()
            })
            .collect();

        let bins = analyzer.analyze(&samples);

        // Check that we got some output
        assert!(!bins.is_empty());

        // The peak should be around 440 Hz; NaNs compare as equal so they cannot panic here.
        let peak_bin = bins
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap();

        let peak_freq = peak_bin as f32 * 44100.0 / 256.0;

        // Bins are ~172 Hz wide at this size, so allow a 200 Hz tolerance.
        assert!((peak_freq - 440.0).abs() < 200.0);
    }

    #[test]
    fn test_frequency_ranges() {
        let mut analyzer = FFTAnalyzer::new(512, 44100.0).expect("512 is valid FFT size");

        // Deterministic low-frequency sine (not white noise): sin(0.1 * i).
        let samples: Vec<f32> = (0..512).map(|i| (i as f32 * 0.1).sin()).collect();

        analyzer.analyze(&samples);

        let bass = analyzer.get_bass();
        let mid = analyzer.get_mid();
        let treble = analyzer.get_treble();

        // All should be non-negative
        assert!(bass >= 0.0);
        assert!(mid >= 0.0);
        assert!(treble >= 0.0);
    }

    #[test]
    fn test_silence_has_no_energy() {
        let mut analyzer = FFTAnalyzer::new(512, 44100.0).expect("512 is valid FFT size");
        let bins = analyzer.analyze(&[0.0; 512]);
        assert!(bins.iter().all(|&b| b == 0.0));
        assert_eq!(analyzer.get_bass(), 0.0);
        assert_eq!(analyzer.get_mid(), 0.0);
        assert_eq!(analyzer.get_treble(), 0.0);
    }

    #[test]
    fn test_short_input_is_zero_padded() {
        let short = [0.3_f32, -0.2, 0.5];
        let mut padded = vec![0.0_f32; 8];
        padded[..short.len()].copy_from_slice(&short);

        // Run a full, loud frame first so stale buffer contents would show up.
        let mut reused = FFTAnalyzer::new(8, 8000.0).expect("8 is valid FFT size");
        reused.analyze(&[1.0; 8]);
        let from_short = reused.analyze(&short).to_vec();

        let mut fresh = FFTAnalyzer::new(8, 8000.0).expect("8 is valid FFT size");
        let from_padded = fresh.analyze(&padded).to_vec();

        assert_eq!(from_short.len(), 4);
        assert_eq!(from_short, from_padded);
    }

    #[test]
    fn test_long_input_is_truncated() {
        let long: Vec<f32> = (0..20).map(|i| (i as f32 * 0.7).cos()).collect();

        let mut a = FFTAnalyzer::new(8, 8000.0).expect("8 is valid FFT size");
        let mut b = FFTAnalyzer::new(8, 8000.0).expect("8 is valid FFT size");

        assert_eq!(a.analyze(&long).to_vec(), b.analyze(&long[..8]).to_vec());
    }

    #[test]
    fn test_new_or_default_falls_back_to_256() {
        let mut analyzer = FFTAnalyzer::new_or_default(100, 44100.0);
        assert_eq!(analyzer.analyze(&[]).len(), 128);
    }

    #[test]
    fn test_invalid_fft_size() {
        // Invalid sizes should return None
        assert!(FFTAnalyzer::new(0, 44100.0).is_none());
        assert!(FFTAnalyzer::new(100, 44100.0).is_none()); // Not power of 2
        assert!(FFTAnalyzer::new(3, 44100.0).is_none()); // Not power of 2

        // Valid sizes should work
        assert!(FFTAnalyzer::new(256, 44100.0).is_some());
        assert!(FFTAnalyzer::new(512, 44100.0).is_some());
        assert!(FFTAnalyzer::new(1024, 44100.0).is_some());
    }
}