onedrop_engine/
audio_input.rs

1//! Real-time audio input capture using cpal.
2//!
3//! On Linux, cpal sits on top of ALSA and transparently sees both
4//! PulseAudio and PipeWire (via its Pulse compatibility layer)
5//! devices. The interesting bit for a music visualizer is selecting
6//! the right *source*: the host's default input is usually the
7//! microphone, but we want the active sink's monitor so the
8//! visualizer reacts to whatever is playing out of the speakers.
9//! [`detect_default_source`] applies that heuristic — pick the
10//! monitor of the default output sink when one exists, otherwise
11//! any `.monitor` source, otherwise the host's default input.
12//! [`with_device`] uses it transparently when no explicit name is
13//! requested.
14
15use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
16use cpal::{Device, Host, Stream, StreamConfig};
17use std::sync::{Arc, Mutex};
18use thiserror::Error;
19
20/// Audio input errors.
21#[derive(Debug, Error)]
22pub enum AudioInputError {
23    #[error("No audio input device available")]
24    NoDevice,
25
26    #[error("Failed to get default input config: {0}")]
27    ConfigError(#[from] cpal::DefaultStreamConfigError),
28
29    #[error("Failed to build audio stream: {0}")]
30    BuildStreamError(#[from] cpal::BuildStreamError),
31
32    #[error("Failed to play audio stream: {0}")]
33    PlayStreamError(#[from] cpal::PlayStreamError),
34}
35
36pub type Result<T> = std::result::Result<T, AudioInputError>;
37
/// Rough category of a capture source, guessed from its name.
///
/// cpal exposes monitor sources (which record whatever the speakers play)
/// and microphones alike as plain `Device`s with no type tag. The CLI
/// `list-audio` subcommand and the GUI device picker therefore rely on
/// this name-based guess to tell them apart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SourceKind {
    /// Monitor of a PulseAudio / PipeWire sink (captures system audio).
    Monitor,
    /// A physical or virtual microphone.
    Microphone,
    /// Names the heuristic cannot place with confidence — bluetooth
    /// headsets, loopback sinks, virtual mixers, and so on.
    Other,
}
53
54/// A capture device enumerated from the active cpal host.
55#[derive(Debug, Clone)]
56pub struct AudioSource {
57    /// Device name as cpal reports it. Pass back to
58    /// [`AudioInput::with_device`] to open this device specifically.
59    pub name: String,
60    /// Classification heuristic. See [`SourceKind`].
61    pub kind: SourceKind,
62    /// `true` when this source matches the host's default *input*
63    /// device — what `with_device(None)` would historically pick.
64    pub is_host_default: bool,
65    /// `true` when this source is the monitor of the host's default
66    /// *output* sink — what the autodetect logic prefers when no
67    /// name is requested. Mutually compatible with `is_host_default`
68    /// (rare, but possible if a user wires the monitor as their
69    /// default input).
70    pub is_autodetect_pick: bool,
71}
72
73/// Read a cpal `Device`'s human name through the
74/// [`DeviceTrait::description`] surface (cpal 0.17 deprecates the bare
75/// `name()` accessor). Returns `None` when the description can't be
76/// fetched — usually because the device was disconnected mid-enumeration.
77fn device_name(d: &Device) -> Option<String> {
78    d.description().ok().map(|desc| desc.name().to_string())
79}
80
81/// Classify a device name with the heuristic documented above.
82/// Pure function so we can unit-test without a real audio host.
83pub(crate) fn classify_source(name: &str) -> SourceKind {
84    let lower = name.to_ascii_lowercase();
85    // `.monitor` suffix is the canonical Pulse / PipeWire signal —
86    // every sink exposes one as `<sink_name>.monitor`.
87    if lower.ends_with(".monitor") || lower.contains("monitor of ") {
88        return SourceKind::Monitor;
89    }
90    // ALSA-style microphone names. `alsa_input.` is the Pulse prefix
91    // for the underlying ALSA capture device.
92    if lower.starts_with("alsa_input.")
93        || lower.contains("microphone")
94        || lower.contains(" mic")
95        || lower.ends_with(" mic")
96        || lower == "mic"
97    {
98        return SourceKind::Microphone;
99    }
100    SourceKind::Other
101}
102
103/// Pick the best autodetect source for a music visualizer.
104///
105/// Order of preference, mirroring what a user expects after
106/// double-clicking the binary while music is already playing:
107///
108/// 1. Monitor of the host's default *output* sink (so the
109///    visualizer reacts to whatever the user is listening to).
110/// 2. Any other `.monitor` source (covers headphones-only or
111///    USB-DAC setups where the "default" sink is something
112///    unexpected).
113/// 3. Host default input (typically the microphone) — preserves
114///    the historical behaviour when no monitor source is available
115///    (e.g. ALSA without `module-loopback`).
116fn detect_default_source(host: &Host) -> Option<Device> {
117    // Try the monitor of the default output first.
118    let default_output_name = host.default_output_device().as_ref().and_then(device_name);
119    if let Some(name) = default_output_name.as_deref() {
120        let want = format!("{}.monitor", name);
121        if let Ok(mut iter) = host.input_devices()
122            && let Some(d) = iter.find(|d| device_name(d).as_deref() == Some(want.as_str()))
123        {
124            log::info!("Audio autodetect: monitor of default sink ({want})");
125            return Some(d);
126        }
127    }
128
129    // Otherwise any monitor source.
130    if let Ok(mut iter) = host.input_devices()
131        && let Some(d) = iter.find(|d| {
132            device_name(d)
133                .map(|n| classify_source(&n) == SourceKind::Monitor)
134                .unwrap_or(false)
135        })
136    {
137        log::info!(
138            "Audio autodetect: first available monitor ({})",
139            device_name(&d).unwrap_or_else(|| "Unknown".to_string())
140        );
141        return Some(d);
142    }
143
144    // Finally fall back to the host default input.
145    host.default_input_device()
146}
147
148/// Enumerate every capture source the active cpal host exposes, with
149/// classification flags. Cheap (no streams opened); the CLI / GUI
150/// device pickers and the autodetect-preview both call this.
151pub fn list_sources() -> Vec<AudioSource> {
152    let host = cpal::default_host();
153    let host_default = host.default_input_device().as_ref().and_then(device_name);
154    let default_output = host.default_output_device().as_ref().and_then(device_name);
155    let autodetect_monitor = default_output.as_deref().map(|n| format!("{}.monitor", n));
156
157    let Ok(devices) = host.input_devices() else {
158        return Vec::new();
159    };
160    devices
161        .filter_map(|d| {
162            let name = device_name(&d)?;
163            let kind = classify_source(&name);
164            let is_host_default = host_default.as_deref() == Some(name.as_str());
165            let is_autodetect_pick =
166                autodetect_monitor.as_deref() == Some(name.as_str()) && kind == SourceKind::Monitor;
167            Some(AudioSource {
168                name,
169                kind,
170                is_host_default,
171                is_autodetect_pick,
172            })
173        })
174        .collect()
175}
176
177/// Real-time audio input capture.
178pub struct AudioInput {
179    /// Audio host
180    _host: Host,
181
182    /// Input device
183    _device: Device,
184
185    /// Input stream
186    _stream: Stream,
187
188    /// Shared buffer for audio samples
189    buffer: Arc<Mutex<Vec<f32>>>,
190
191    /// Sample rate
192    sample_rate: u32,
193}
194
195impl AudioInput {
196    /// Enumerate the names of every input device offered by cpal's
197    /// default host. Used by the GUI's Options panel to populate the
198    /// audio device picker. Devices that fail `name()` (rare —
199    /// usually disconnected mid-enumeration) are silently skipped.
200    ///
201    /// Prefer [`list_sources`] when callers want the
202    /// monitor/microphone classification flags (CLI list-audio
203    /// subcommand, future GUI picker).
204    pub fn list_input_devices() -> Vec<String> {
205        list_sources().into_iter().map(|s| s.name).collect()
206    }
207
208    /// Create a new audio input capture.
209    pub fn new() -> Result<Self> {
210        Self::with_device(None)
211    }
212
213    /// Create a new audio input capture targeting a specific device by
214    /// name. `name = None` triggers the music-visualizer autodetect:
215    /// the monitor of the host's default output sink is preferred over
216    /// the host default input (the microphone), so a fresh install
217    /// reacts to playing music without any GUI / CLI configuration.
218    /// See [`detect_default_source`] for the full preference order.
219    /// Unknown names log a warning and fall back to the same
220    /// autodetect.
221    pub fn with_device(name: Option<&str>) -> Result<Self> {
222        let host = cpal::default_host();
223
224        // Pick the named device when one is requested, otherwise run
225        // the autodetect. Unknown names log a warning and fall back —
226        // matches what happens when the user unplugs a saved device.
227        let device = if let Some(want) = name {
228            let mut matched: Option<Device> = None;
229            if let Ok(devices) = host.input_devices() {
230                for d in devices {
231                    if device_name(&d).as_deref() == Some(want) {
232                        matched = Some(d);
233                        break;
234                    }
235                }
236            }
237            match matched {
238                Some(d) => d,
239                None => {
240                    log::warn!("Input device {want:?} not found, using autodetect");
241                    detect_default_source(&host).ok_or(AudioInputError::NoDevice)?
242                }
243            }
244        } else {
245            detect_default_source(&host).ok_or(AudioInputError::NoDevice)?
246        };
247
248        log::info!(
249            "Using audio input device: {}",
250            device
251                .description()
252                .map(|d| d.to_string())
253                .unwrap_or_else(|_| "Unknown".to_string())
254        );
255
256        // Get default input config
257        let config = device.default_input_config()?;
258        let sample_rate: u32 = config.sample_rate();
259
260        log::info!(
261            "Audio input config: {} Hz, {} channels",
262            sample_rate,
263            config.channels()
264        );
265
266        // Create shared buffer
267        let buffer = Arc::new(Mutex::new(Vec::new()));
268        let buffer_clone = buffer.clone();
269
270        // Build input stream
271        let stream_config: StreamConfig = config.into();
272        let stream = device.build_input_stream(
273            &stream_config,
274            move |data: &[f32], _: &cpal::InputCallbackInfo| {
275                // Copy audio data to buffer (handle mutex poisoning gracefully)
276                if let Ok(mut buf) = buffer_clone.lock() {
277                    buf.clear();
278                    buf.extend_from_slice(data);
279                }
280            },
281            |err| {
282                log::error!("Audio input stream error: {}", err);
283            },
284            None,
285        )?;
286
287        // Start the stream
288        stream.play()?;
289
290        log::info!("Audio input stream started");
291
292        Ok(Self {
293            _host: host,
294            _device: device,
295            _stream: stream,
296            buffer,
297            sample_rate,
298        })
299    }
300
301    /// Get the latest audio samples.
302    /// Returns a copy of the current buffer.
303    pub fn get_samples(&self) -> Vec<f32> {
304        self.buffer
305            .lock()
306            .map(|buf| buf.clone())
307            .unwrap_or_default()
308    }
309
310    /// Get the sample rate.
311    pub fn sample_rate(&self) -> u32 {
312        self.sample_rate
313    }
314
315    /// Get a fixed number of samples for processing.
316    /// If not enough samples are available, returns zeros.
317    pub fn get_fixed_samples(&self, count: usize) -> Vec<f32> {
318        self.buffer
319            .lock()
320            .map(|buf| {
321                if buf.len() >= count {
322                    buf[..count].to_vec()
323                } else {
324                    // Pad with zeros if not enough samples
325                    let mut result = buf.clone();
326                    result.resize(count, 0.0);
327                    result
328                }
329            })
330            .unwrap_or_else(|_| vec![0.0; count])
331    }
332}
333
334/// Audio input with FFT analysis for bass/mid/treb extraction.
335pub struct AudioAnalysisInput {
336    /// Audio input
337    input: AudioInput,
338
339    /// FFT planner
340    fft: Arc<dyn rustfft::Fft<f32>>,
341
342    /// FFT buffer size
343    fft_size: usize,
344}
345
346impl AudioAnalysisInput {
347    /// Create a new audio analysis input.
348    ///
349    /// # Arguments
350    /// * `fft_size` - Size of FFT window (power of 2, e.g., 2048)
351    pub fn new(fft_size: usize) -> Result<Self> {
352        let input = AudioInput::new()?;
353
354        // Create FFT planner
355        let mut planner = rustfft::FftPlanner::new();
356        let fft = planner.plan_fft_forward(fft_size);
357
358        log::info!("Audio analysis initialized with FFT size {}", fft_size);
359
360        Ok(Self {
361            input,
362            fft,
363            fft_size,
364        })
365    }
366
367    /// Analyze audio and extract bass, mid, treb levels.
368    /// Returns (bass, mid, treb) in range [0.0, 1.0].
369    pub fn analyze(&self) -> (f32, f32, f32) {
370        use rustfft::num_complex::Complex;
371
372        // Get samples
373        let samples = self.input.get_fixed_samples(self.fft_size);
374
375        // Convert to complex
376        let mut buffer: Vec<Complex<f32>> = samples.iter().map(|&s| Complex::new(s, 0.0)).collect();
377
378        // Apply Hann window
379        for (i, sample) in buffer.iter_mut().enumerate() {
380            let window =
381                0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / self.fft_size as f32).cos());
382            *sample *= window;
383        }
384
385        // Perform FFT
386        self.fft.process(&mut buffer);
387
388        // Calculate magnitude spectrum
389        let magnitudes: Vec<f32> = buffer
390            .iter()
391            .take(self.fft_size / 2) // Only use first half (Nyquist)
392            .map(|c| c.norm())
393            .collect();
394
395        // Extract bass, mid, treb
396        // Frequency bins: bin_freq = sample_rate * bin_index / fft_size
397        let sample_rate = self.input.sample_rate() as f32;
398        let _bin_to_freq = |bin: usize| sample_rate * bin as f32 / self.fft_size as f32;
399
400        // Bass: 20-250 Hz
401        // Mid: 250-2000 Hz
402        // Treb: 2000-20000 Hz
403        let bass_end = (250.0 * self.fft_size as f32 / sample_rate) as usize;
404        let mid_end = (2000.0 * self.fft_size as f32 / sample_rate) as usize;
405
406        // Bounds checking to prevent panics
407        let bass_end = bass_end.max(1).min(magnitudes.len());
408        let mid_end = mid_end.max(bass_end).min(magnitudes.len());
409
410        let bass: f32 = if bass_end > 1 {
411            magnitudes[1..bass_end].iter().sum::<f32>() / (bass_end - 1) as f32
412        } else {
413            0.0
414        };
415        let mid: f32 = if mid_end > bass_end {
416            magnitudes[bass_end..mid_end].iter().sum::<f32>() / (mid_end - bass_end) as f32
417        } else {
418            0.0
419        };
420        let treb: f32 = if magnitudes.len() > mid_end {
421            magnitudes[mid_end..].iter().sum::<f32>() / (magnitudes.len() - mid_end) as f32
422        } else {
423            0.0
424        };
425
426        // Normalize to [0, 1] range (approximate)
427        let normalize = |x: f32| (x * 10.0).min(1.0);
428
429        (normalize(bass), normalize(mid), normalize(treb))
430    }
431
432    /// Get the sample rate.
433    pub fn sample_rate(&self) -> u32 {
434        self.input.sample_rate()
435    }
436
437    /// Get raw samples for further processing.
438    pub fn get_samples(&self) -> Vec<f32> {
439        self.input.get_samples()
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Check that every `(name, expected)` pair classifies as expected.
    fn assert_kinds(cases: &[(&str, SourceKind)]) {
        for &(name, expected) in cases {
            assert_eq!(classify_source(name), expected, "classifying {name:?}");
        }
    }

    #[test]
    fn classify_monitor_suffix() {
        assert_kinds(&[
            (
                "alsa_output.pci-0000_00_1f.3.analog-stereo.monitor",
                SourceKind::Monitor,
            ),
            ("Monitor of Built-in Audio", SourceKind::Monitor),
            ("hdmi-stereo-extra1.monitor", SourceKind::Monitor),
        ]);
    }

    #[test]
    fn classify_microphone() {
        assert_kinds(&[
            (
                "alsa_input.pci-0000_00_1f.3.analog-stereo",
                SourceKind::Microphone,
            ),
            ("Internal Microphone", SourceKind::Microphone),
            ("Headset Mic", SourceKind::Microphone),
        ]);
    }

    #[test]
    fn classify_other_for_unknown_names() {
        // PipeWire native node names aren't always tagged, so they are
        // conservatively classified as Other. That keeps the autodetect
        // from *guessing* the wrong source.
        assert_kinds(&[
            ("Family 17h HD Audio Controller", SourceKind::Other),
            ("USB Audio Device", SourceKind::Other),
            ("", SourceKind::Other),
        ]);
    }

    #[test]
    fn classify_is_case_insensitive() {
        assert_kinds(&[
            ("Foo.MONITOR", SourceKind::Monitor),
            ("BUILT-IN MICROPHONE", SourceKind::Microphone),
        ]);
    }

    #[test]
    #[ignore] // Requires audio device
    fn test_audio_input_creation() {
        assert!(AudioInput::new().is_ok(), "Failed to create audio input");
    }

    #[test]
    #[ignore] // Requires audio device
    fn test_audio_analysis_creation() {
        assert!(
            AudioAnalysisInput::new(2048).is_ok(),
            "Failed to create audio analysis input"
        );
    }

    #[test]
    #[ignore] // Requires audio device
    fn test_audio_capture() {
        let input = AudioInput::new().unwrap();

        // Give the stream a moment to deliver its first chunk.
        std::thread::sleep(std::time::Duration::from_millis(100));

        assert!(
            !input.get_samples().is_empty(),
            "Should have captured some samples"
        );
    }

    #[test]
    #[ignore] // Requires audio device
    fn test_audio_analysis() {
        let input = AudioAnalysisInput::new(2048).unwrap();

        // Give the stream a moment to deliver its first chunk.
        std::thread::sleep(std::time::Duration::from_millis(100));

        let (bass, mid, treb) = input.analyze();

        // Every level must land in the documented range.
        for (label, value) in [("Bass", bass), ("Mid", mid), ("Treb", treb)] {
            assert!((0.0..=1.0).contains(&value), "{label} out of range: {value}");
        }

        println!(
            "Audio levels - Bass: {:.3}, Mid: {:.3}, Treb: {:.3}",
            bass, mid, treb
        );
    }
}