1use std::f32::consts::PI;
4
/// Spectrum analyzer that turns fixed-size audio frames into magnitude bins.
///
/// Construct it with [`FFTAnalyzer::new`] (or `new_or_default`), pass frames to
/// [`FFTAnalyzer::analyze`], then read coarse band levels with the `get_*` methods.
#[derive(Debug, Clone)]
pub struct FFTAnalyzer {
    /// Number of samples per analysis frame; `new` only accepts powers of two.
    fft_size: usize,

    /// Hann window coefficients, one per sample of a frame.
    window: Vec<f32>,

    /// Complex scratch buffer, interleaved as `[re0, im0, re1, im1, ...]`,
    /// `2 * fft_size` values long.
    fft_buffer: Vec<f32>,

    /// Magnitude spectrum from the most recent `analyze` call (`fft_size / 2` entries).
    bins: Vec<f32>,

    /// Sample rate in Hz, used to map bin indices to frequencies.
    sample_rate: f32,
}
22
/// Error value describing a failed FFT operation.
///
/// Its `Display` output is `FFT Error: ` followed by [`FFTError::message`].
#[derive(Clone, Debug)]
pub struct FFTError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for FFTError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "FFT Error: {}", self.message)
    }
}

/// Marks `FFTError` as a standard error so it can be used as `dyn std::error::Error`.
impl std::error::Error for FFTError {}
36
37impl FFTAnalyzer {
38 pub fn new(fft_size: usize, sample_rate: f32) -> Option<Self> {
41 if !fft_size.is_power_of_two() || fft_size == 0 {
42 log::warn!("FFT size must be a power of 2, got {}", fft_size);
43 return None;
44 }
45
46 let window: Vec<f32> = (0..fft_size)
48 .map(|i| {
49 let x = i as f32 / fft_size as f32;
50 0.5 * (1.0 - (2.0 * PI * x).cos())
51 })
52 .collect();
53
54 Some(Self {
55 fft_size,
56 window,
57 fft_buffer: vec![0.0; fft_size * 2], bins: vec![0.0; fft_size / 2],
59 sample_rate,
60 })
61 }
62
63 pub fn new_or_default(fft_size: usize, sample_rate: f32) -> Self {
66 Self::new(fft_size, sample_rate).unwrap_or_else(|| {
67 log::warn!("Invalid FFT size {}, using default 256", fft_size);
68 Self::new(256, sample_rate).expect("256 is always valid")
69 })
70 }
71
72 pub fn analyze(&mut self, samples: &[f32]) -> &[f32] {
74 let num_samples = samples.len().min(self.fft_size);
76
77 for (i, (sample, window)) in samples
79 .iter()
80 .zip(self.window.iter())
81 .take(num_samples)
82 .enumerate()
83 {
84 self.fft_buffer[i * 2] = sample * window; self.fft_buffer[i * 2 + 1] = 0.0; }
87
88 for i in num_samples..self.fft_size {
90 self.fft_buffer[i * 2] = 0.0;
91 self.fft_buffer[i * 2 + 1] = 0.0;
92 }
93
94 self.fft_inplace();
96
97 for i in 0..self.bins.len() {
99 let real = self.fft_buffer[i * 2];
100 let imag = self.fft_buffer[i * 2 + 1];
101 self.bins[i] = (real * real + imag * imag).sqrt() / self.fft_size as f32;
102 }
103
104 &self.bins
105 }
106
107 pub fn get_bass(&self) -> f32 {
109 self.get_frequency_range(20.0, 250.0)
110 }
111
112 pub fn get_mid(&self) -> f32 {
114 self.get_frequency_range(250.0, 2000.0)
115 }
116
117 pub fn get_treble(&self) -> f32 {
119 self.get_frequency_range(2000.0, 20000.0)
120 }
121
122 fn get_frequency_range(&self, min_freq: f32, max_freq: f32) -> f32 {
124 let bin_width = self.sample_rate / self.fft_size as f32;
125 let min_bin = (min_freq / bin_width) as usize;
126 let max_bin = ((max_freq / bin_width) as usize).min(self.bins.len());
127
128 if min_bin >= max_bin {
129 return 0.0;
130 }
131
132 let sum: f32 = self.bins[min_bin..max_bin].iter().sum();
133 sum / (max_bin - min_bin) as f32
134 }
135
136 fn fft_inplace(&mut self) {
138 let n = self.fft_size;
139
140 let mut j = 0;
142 for i in 0..n {
143 if i < j {
144 self.fft_buffer.swap(i * 2, j * 2);
145 self.fft_buffer.swap(i * 2 + 1, j * 2 + 1);
146 }
147
148 let mut m = n / 2;
149 while m >= 1 && j >= m {
150 j -= m;
151 m /= 2;
152 }
153 j += m;
154 }
155
156 let mut len = 2;
158 while len <= n {
159 let angle = -2.0 * PI / len as f32;
160 let wlen_real = angle.cos();
161 let wlen_imag = angle.sin();
162
163 let mut i = 0;
164 while i < n {
165 let mut w_real = 1.0;
166 let mut w_imag = 0.0;
167
168 for j in 0..len / 2 {
169 let u_idx = (i + j) * 2;
170 let v_idx = (i + j + len / 2) * 2;
171
172 let u_real = self.fft_buffer[u_idx];
173 let u_imag = self.fft_buffer[u_idx + 1];
174 let v_real = self.fft_buffer[v_idx];
175 let v_imag = self.fft_buffer[v_idx + 1];
176
177 let t_real = w_real * v_real - w_imag * v_imag;
178 let t_imag = w_real * v_imag + w_imag * v_real;
179
180 self.fft_buffer[u_idx] = u_real + t_real;
181 self.fft_buffer[u_idx + 1] = u_imag + t_imag;
182 self.fft_buffer[v_idx] = u_real - t_real;
183 self.fft_buffer[v_idx + 1] = u_imag - t_imag;
184
185 let w_temp = w_real;
186 w_real = w_real * wlen_real - w_imag * wlen_imag;
187 w_imag = w_temp * wlen_imag + w_imag * wlen_real;
188 }
189
190 i += len;
191 }
192
193 len *= 2;
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit-amplitude sine with exactly `cycles` periods across `len` samples,
    /// so its energy sits on the centre of bin `cycles`.
    fn bin_centered_sine(cycles: usize, len: usize) -> Vec<f32> {
        (0..len)
            .map(|i| (2.0 * PI * cycles as f32 * i as f32 / len as f32).sin())
            .collect()
    }

    #[test]
    fn test_fft_analyzer() {
        let mut analyzer = FFTAnalyzer::new(256, 44100.0).expect("256 is valid FFT size");

        let samples: Vec<f32> = (0..256)
            .map(|i| {
                let t = i as f32 / 44100.0;
                (2.0 * PI * 440.0 * t).sin()
            })
            .collect();

        let bins = analyzer.analyze(&samples);

        assert_eq!(bins.len(), 128);

        let peak_bin = bins
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap();

        let peak_freq = peak_bin as f32 * 44100.0 / 256.0;

        // Bins are ~172 Hz wide at this size, so allow about one bin of error.
        assert!((peak_freq - 440.0).abs() < 200.0);
    }

    #[test]
    fn test_bin_centered_tone_magnitudes() {
        // A unit sine exactly on bin k, under a periodic Hann window and 1/N
        // scaling, gives 0.25 at bin k, 0.125 at bins k-1 and k+1, and ~0 elsewhere.
        let mut analyzer = FFTAnalyzer::new(256, 44100.0).expect("256 is valid FFT size");
        let bins = analyzer.analyze(&bin_centered_sine(32, 256));

        for (k, &mag) in bins.iter().enumerate() {
            let expected = match k {
                32 => 0.25,
                31 | 33 => 0.125,
                _ => 0.0,
            };
            assert!(
                (mag - expected).abs() < 1e-3,
                "bin {}: got {}, expected {}",
                k,
                mag,
                expected
            );
        }
    }

    #[test]
    fn test_frequency_ranges() {
        let mut analyzer = FFTAnalyzer::new(512, 44100.0).expect("512 is valid FFT size");

        // 0.1 rad/sample at 44.1 kHz is roughly 700 Hz, inside the mid band.
        let samples: Vec<f32> = (0..512).map(|i| (i as f32 * 0.1).sin()).collect();

        analyzer.analyze(&samples);

        let bass = analyzer.get_bass();
        let mid = analyzer.get_mid();
        let treble = analyzer.get_treble();

        // These also fail on NaN, which would indicate a numeric fault.
        assert!(bass >= 0.0);
        assert!(mid >= 0.0);
        assert!(treble >= 0.0);
        assert!(mid > bass && mid > treble, "bass={} mid={} treble={}", bass, mid, treble);
    }

    #[test]
    fn test_silence_and_input_length_handling() {
        let mut analyzer = FFTAnalyzer::new(256, 44100.0).expect("256 is valid FFT size");

        // Analyze a tone first so the following frames prove no state leaks between calls.
        analyzer.analyze(&bin_centered_sine(32, 256));

        // Empty input is fully zero-padded.
        let bins = analyzer.analyze(&[]);
        assert_eq!(bins.len(), 128);
        assert!(bins.iter().all(|&m| m == 0.0));
        assert_eq!(analyzer.get_bass(), 0.0);
        assert_eq!(analyzer.get_mid(), 0.0);
        assert_eq!(analyzer.get_treble(), 0.0);

        // Inputs longer than the FFT size are truncated rather than rejected.
        let long = vec![0.0; 1000];
        assert_eq!(analyzer.analyze(&long).len(), 128);
    }

    #[test]
    fn test_new_or_default() {
        let mut fallback = FFTAnalyzer::new_or_default(100, 44100.0);
        assert_eq!(fallback.analyze(&[]).len(), 128);

        let mut requested = FFTAnalyzer::new_or_default(1024, 44100.0);
        assert_eq!(requested.analyze(&[]).len(), 512);
    }

    #[test]
    fn test_invalid_fft_size() {
        assert!(FFTAnalyzer::new(0, 44100.0).is_none());
        assert!(FFTAnalyzer::new(100, 44100.0).is_none());
        assert!(FFTAnalyzer::new(3, 44100.0).is_none());
        assert!(FFTAnalyzer::new(1, 44100.0).is_some());
        assert!(FFTAnalyzer::new(256, 44100.0).is_some());
        assert!(FFTAnalyzer::new(512, 44100.0).is_some());
        assert!(FFTAnalyzer::new(1024, 44100.0).is_some());
    }
}