1use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
16use cpal::{Device, Host, Stream, StreamConfig};
17use std::sync::{Arc, Mutex};
18use thiserror::Error;
19
20#[derive(Debug, Error)]
22pub enum AudioInputError {
23 #[error("No audio input device available")]
24 NoDevice,
25
26 #[error("Failed to get default input config: {0}")]
27 ConfigError(#[from] cpal::DefaultStreamConfigError),
28
29 #[error("Failed to build audio stream: {0}")]
30 BuildStreamError(#[from] cpal::BuildStreamError),
31
32 #[error("Failed to play audio stream: {0}")]
33 PlayStreamError(#[from] cpal::PlayStreamError),
34}
35
36pub type Result<T> = std::result::Result<T, AudioInputError>;
37
/// Coarse, name-based category of a capture source (see `classify_source`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SourceKind {
    /// A loopback source: the name ends in `.monitor` or contains `monitor of `.
    Monitor,
    /// A microphone-like input.
    Microphone,
    /// Anything the name heuristics could not place elsewhere.
    Other,
}
53
54#[derive(Debug, Clone)]
56pub struct AudioSource {
57 pub name: String,
60 pub kind: SourceKind,
62 pub is_host_default: bool,
65 pub is_autodetect_pick: bool,
71}
72
73fn device_name(d: &Device) -> Option<String> {
78 d.description().ok().map(|desc| desc.name().to_string())
79}
80
81pub(crate) fn classify_source(name: &str) -> SourceKind {
84 let lower = name.to_ascii_lowercase();
85 if lower.ends_with(".monitor") || lower.contains("monitor of ") {
88 return SourceKind::Monitor;
89 }
90 if lower.starts_with("alsa_input.")
93 || lower.contains("microphone")
94 || lower.contains(" mic")
95 || lower.ends_with(" mic")
96 || lower == "mic"
97 {
98 return SourceKind::Microphone;
99 }
100 SourceKind::Other
101}
102
103fn detect_default_source(host: &Host) -> Option<Device> {
117 let default_output_name = host.default_output_device().as_ref().and_then(device_name);
119 if let Some(name) = default_output_name.as_deref() {
120 let want = format!("{}.monitor", name);
121 if let Ok(mut iter) = host.input_devices()
122 && let Some(d) = iter.find(|d| device_name(d).as_deref() == Some(want.as_str()))
123 {
124 log::info!("Audio autodetect: monitor of default sink ({want})");
125 return Some(d);
126 }
127 }
128
129 if let Ok(mut iter) = host.input_devices()
131 && let Some(d) = iter.find(|d| {
132 device_name(d)
133 .map(|n| classify_source(&n) == SourceKind::Monitor)
134 .unwrap_or(false)
135 })
136 {
137 log::info!(
138 "Audio autodetect: first available monitor ({})",
139 device_name(&d).unwrap_or_else(|| "Unknown".to_string())
140 );
141 return Some(d);
142 }
143
144 host.default_input_device()
146}
147
148pub fn list_sources() -> Vec<AudioSource> {
152 let host = cpal::default_host();
153 let host_default = host.default_input_device().as_ref().and_then(device_name);
154 let default_output = host.default_output_device().as_ref().and_then(device_name);
155 let autodetect_monitor = default_output.as_deref().map(|n| format!("{}.monitor", n));
156
157 let Ok(devices) = host.input_devices() else {
158 return Vec::new();
159 };
160 devices
161 .filter_map(|d| {
162 let name = device_name(&d)?;
163 let kind = classify_source(&name);
164 let is_host_default = host_default.as_deref() == Some(name.as_str());
165 let is_autodetect_pick =
166 autodetect_monitor.as_deref() == Some(name.as_str()) && kind == SourceKind::Monitor;
167 Some(AudioSource {
168 name,
169 kind,
170 is_host_default,
171 is_autodetect_pick,
172 })
173 })
174 .collect()
175}
176
177pub struct AudioInput {
179 _host: Host,
181
182 _device: Device,
184
185 _stream: Stream,
187
188 buffer: Arc<Mutex<Vec<f32>>>,
190
191 sample_rate: u32,
193}
194
195impl AudioInput {
196 pub fn list_input_devices() -> Vec<String> {
205 list_sources().into_iter().map(|s| s.name).collect()
206 }
207
208 pub fn new() -> Result<Self> {
210 Self::with_device(None)
211 }
212
213 pub fn with_device(name: Option<&str>) -> Result<Self> {
222 let host = cpal::default_host();
223
224 let device = if let Some(want) = name {
228 let mut matched: Option<Device> = None;
229 if let Ok(devices) = host.input_devices() {
230 for d in devices {
231 if device_name(&d).as_deref() == Some(want) {
232 matched = Some(d);
233 break;
234 }
235 }
236 }
237 match matched {
238 Some(d) => d,
239 None => {
240 log::warn!("Input device {want:?} not found, using autodetect");
241 detect_default_source(&host).ok_or(AudioInputError::NoDevice)?
242 }
243 }
244 } else {
245 detect_default_source(&host).ok_or(AudioInputError::NoDevice)?
246 };
247
248 log::info!(
249 "Using audio input device: {}",
250 device
251 .description()
252 .map(|d| d.to_string())
253 .unwrap_or_else(|_| "Unknown".to_string())
254 );
255
256 let config = device.default_input_config()?;
258 let sample_rate: u32 = config.sample_rate();
259
260 log::info!(
261 "Audio input config: {} Hz, {} channels",
262 sample_rate,
263 config.channels()
264 );
265
266 let buffer = Arc::new(Mutex::new(Vec::new()));
268 let buffer_clone = buffer.clone();
269
270 let stream_config: StreamConfig = config.into();
272 let stream = device.build_input_stream(
273 &stream_config,
274 move |data: &[f32], _: &cpal::InputCallbackInfo| {
275 if let Ok(mut buf) = buffer_clone.lock() {
277 buf.clear();
278 buf.extend_from_slice(data);
279 }
280 },
281 |err| {
282 log::error!("Audio input stream error: {}", err);
283 },
284 None,
285 )?;
286
287 stream.play()?;
289
290 log::info!("Audio input stream started");
291
292 Ok(Self {
293 _host: host,
294 _device: device,
295 _stream: stream,
296 buffer,
297 sample_rate,
298 })
299 }
300
301 pub fn get_samples(&self) -> Vec<f32> {
304 self.buffer
305 .lock()
306 .map(|buf| buf.clone())
307 .unwrap_or_default()
308 }
309
310 pub fn sample_rate(&self) -> u32 {
312 self.sample_rate
313 }
314
315 pub fn get_fixed_samples(&self, count: usize) -> Vec<f32> {
318 self.buffer
319 .lock()
320 .map(|buf| {
321 if buf.len() >= count {
322 buf[..count].to_vec()
323 } else {
324 let mut result = buf.clone();
326 result.resize(count, 0.0);
327 result
328 }
329 })
330 .unwrap_or_else(|_| vec![0.0; count])
331 }
332}
333
334pub struct AudioAnalysisInput {
336 input: AudioInput,
338
339 fft: Arc<dyn rustfft::Fft<f32>>,
341
342 fft_size: usize,
344}
345
346impl AudioAnalysisInput {
347 pub fn new(fft_size: usize) -> Result<Self> {
352 let input = AudioInput::new()?;
353
354 let mut planner = rustfft::FftPlanner::new();
356 let fft = planner.plan_fft_forward(fft_size);
357
358 log::info!("Audio analysis initialized with FFT size {}", fft_size);
359
360 Ok(Self {
361 input,
362 fft,
363 fft_size,
364 })
365 }
366
367 pub fn analyze(&self) -> (f32, f32, f32) {
370 use rustfft::num_complex::Complex;
371
372 let samples = self.input.get_fixed_samples(self.fft_size);
374
375 let mut buffer: Vec<Complex<f32>> = samples.iter().map(|&s| Complex::new(s, 0.0)).collect();
377
378 for (i, sample) in buffer.iter_mut().enumerate() {
380 let window =
381 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / self.fft_size as f32).cos());
382 *sample *= window;
383 }
384
385 self.fft.process(&mut buffer);
387
388 let magnitudes: Vec<f32> = buffer
390 .iter()
391 .take(self.fft_size / 2) .map(|c| c.norm())
393 .collect();
394
395 let sample_rate = self.input.sample_rate() as f32;
398 let _bin_to_freq = |bin: usize| sample_rate * bin as f32 / self.fft_size as f32;
399
400 let bass_end = (250.0 * self.fft_size as f32 / sample_rate) as usize;
404 let mid_end = (2000.0 * self.fft_size as f32 / sample_rate) as usize;
405
406 let bass_end = bass_end.max(1).min(magnitudes.len());
408 let mid_end = mid_end.max(bass_end).min(magnitudes.len());
409
410 let bass: f32 = if bass_end > 1 {
411 magnitudes[1..bass_end].iter().sum::<f32>() / (bass_end - 1) as f32
412 } else {
413 0.0
414 };
415 let mid: f32 = if mid_end > bass_end {
416 magnitudes[bass_end..mid_end].iter().sum::<f32>() / (mid_end - bass_end) as f32
417 } else {
418 0.0
419 };
420 let treb: f32 = if magnitudes.len() > mid_end {
421 magnitudes[mid_end..].iter().sum::<f32>() / (magnitudes.len() - mid_end) as f32
422 } else {
423 0.0
424 };
425
426 let normalize = |x: f32| (x * 10.0).min(1.0);
428
429 (normalize(bass), normalize(mid), normalize(treb))
430 }
431
432 pub fn sample_rate(&self) -> u32 {
434 self.input.sample_rate()
435 }
436
437 pub fn get_samples(&self) -> Vec<f32> {
439 self.input.get_samples()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every name in `names` is classified as `expected`.
    fn assert_all_classify_as(names: &[&str], expected: SourceKind) {
        for &name in names {
            assert_eq!(classify_source(name), expected);
        }
    }

    #[test]
    fn classify_monitor_suffix() {
        assert_all_classify_as(
            &[
                "alsa_output.pci-0000_00_1f.3.analog-stereo.monitor",
                "Monitor of Built-in Audio",
                "hdmi-stereo-extra1.monitor",
            ],
            SourceKind::Monitor,
        );
    }

    #[test]
    fn classify_microphone() {
        assert_all_classify_as(
            &[
                "alsa_input.pci-0000_00_1f.3.analog-stereo",
                "Internal Microphone",
                "Headset Mic",
            ],
            SourceKind::Microphone,
        );
    }

    #[test]
    fn classify_other_for_unknown_names() {
        assert_all_classify_as(
            &["Family 17h HD Audio Controller", "USB Audio Device", ""],
            SourceKind::Other,
        );
    }

    #[test]
    fn classify_is_case_insensitive() {
        assert_all_classify_as(&["Foo.MONITOR"], SourceKind::Monitor);
        assert_all_classify_as(&["BUILT-IN MICROPHONE"], SourceKind::Microphone);
    }

    #[test]
    #[ignore] // opens a real capture device
    fn test_audio_input_creation() {
        let created = AudioInput::new();
        assert!(created.is_ok(), "Failed to create audio input");
    }

    #[test]
    #[ignore] // opens a real capture device
    fn test_audio_analysis_creation() {
        let created = AudioAnalysisInput::new(2048);
        assert!(created.is_ok(), "Failed to create audio analysis input");
    }

    #[test]
    #[ignore] // opens a real capture device
    fn test_audio_capture() {
        let capture = AudioInput::new().unwrap();

        // Give the stream time to deliver at least one callback.
        std::thread::sleep(std::time::Duration::from_millis(100));

        let captured = capture.get_samples();
        assert!(!captured.is_empty(), "Should have captured some samples");
    }

    #[test]
    #[ignore] // opens a real capture device
    fn test_audio_analysis() {
        let analysis = AudioAnalysisInput::new(2048).unwrap();

        // Give the stream time to deliver at least one callback.
        std::thread::sleep(std::time::Duration::from_millis(100));

        let (bass, mid, treb) = analysis.analyze();

        for (label, level) in [("Bass", bass), ("Mid", mid), ("Treb", treb)] {
            assert!((0.0..=1.0).contains(&level), "{label} out of range: {level}");
        }

        println!(
            "Audio levels - Bass: {:.3}, Mid: {:.3}, Treb: {:.3}",
            bass, mid, treb
        );
    }
}