onedrop_codegen/
prelude.rs

1//! Standard WGSL prelude injected ahead of every translated user shader.
2//!
3//! The prelude exposes the MilkDrop 2 uniform set (audio, time, aspect,
4//! `q1..q32`, etc.) under a single `uniforms` binding so warp and comp
5//! HLSL shaders can be translated to WGSL with a known reference target.
6//!
7//! Two artefacts must stay in lockstep:
8//! - [`ShaderUniforms`] — the Rust mirror, repr-C, Pod/Zeroable. The
9//!   renderer fills it each frame from `RenderState`/`MilkContext` and
10//!   uploads it via `wgpu::Queue::write_buffer`.
11//! - [`SHADER_UNIFORMS_WGSL`] — the WGSL `struct` declaration plus the
12//!   `@group(0) @binding(0) var<uniform> uniforms: ShaderUniforms;` line.
13//!   Concatenated to the front of every translated user shader before
14//!   naga validates and the renderer compiles it.
15//!
16//! Layout discipline (matches WGSL/std140-ish uniform alignment):
17//! - All scalar `f32` fields packed at the top in groups of four (each
18//!   group is one 16-byte slot).
19//! - All `vec4<f32>` fields after, contiguous (each is exactly 16 bytes,
20//!   so no further padding is needed).
21//! - Trailing `array<vec4<f32>, N>` for `q1..q32` — eight vec4 slots = 32
22//!   floats.
23//!
//! The size in bytes is asserted by the
//! `tests::rust_struct_size_matches_wgsl_layout` test in this module
//! against a naga-parsed WGSL module: any drift between the Rust struct
//! and the WGSL declaration that changes the total size trips the test.
28
29use bytemuck::{Pod, Zeroable};
30
/// Number of `q*` channels exposed to user shaders. MilkDrop 2 ships
/// `q1..q32`; the engine tracks 64 internally but only the first 32 cross
/// into shader-land.
///
/// This is the length of the array accepted by
/// [`ShaderUniforms::set_q_channels`]; the channels are stored four per
/// vec4 slot, so the `q` field of [`ShaderUniforms`] holds
/// `SHADER_Q_CHANNELS / 4 = 8` slots.
pub const SHADER_Q_CHANNELS: usize = 32;
35
/// Rust mirror of the WGSL `ShaderUniforms` struct.
///
/// Field order, alignment, and padding must exactly match
/// [`SHADER_UNIFORMS_WGSL`]. The `tests` module in this file cross-checks
/// the total size of this struct against the size naga computes for the
/// WGSL declaration.
///
/// Every field is an `f32` or an array of `f32`, so under `#[repr(C)]`
/// the fields are laid out back to back with no implicit padding.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct ShaderUniforms {
    // ---- scalar slot 0 (16 B) ----
    /// Wall-clock time in seconds since preset load.
    pub time: f32,
    /// Smoothed FPS estimate.
    pub fps: f32,
    /// Frame counter (cast from u32).
    pub frame: f32,
    /// Position within the preset's display window, `[0, 1]`. 0 at load,
    /// 1 just before the next hard cut.
    pub progress: f32,

    // ---- scalar slot 1 (16 B) ----
    /// Bass band level (raw, ~0 baseline, can spike >1 on beats).
    pub bass: f32,
    /// Mid band level.
    pub mid: f32,
    /// Treble band level.
    pub treb: f32,
    /// `(bass + mid + treb) / 3`.
    pub vol: f32,

    // ---- scalar slot 2 (16 B) ----
    /// Smoothed (attenuated) bass level — exponential moving average.
    pub bass_att: f32,
    /// Attenuated mid level; presumably smoothed the same way as
    /// `bass_att` (confirm against the code that fills this struct).
    pub mid_att: f32,
    /// Attenuated treble level; presumably smoothed the same way as
    /// `bass_att` (confirm against the code that fills this struct).
    pub treb_att: f32,
    /// Attenuated overall volume; presumably the smoothed counterpart of
    /// `vol` (confirm against the code that fills this struct).
    pub vol_att: f32,

    // ---- scalar slot 3 (16 B): blur level 1 / 2 ----
    /// `bN_min` / `bN_max` are remap bounds — the comp shader applies
    /// `lerp(min, max, sample)` to recover usable contrast from the
    /// downsampled blur pyramid.
    pub blur1_min: f32,
    /// Upper remap bound for blur level 1; see `blur1_min`.
    pub blur1_max: f32,
    /// Lower remap bound for blur level 2; see `blur1_min`.
    pub blur2_min: f32,
    /// Upper remap bound for blur level 2; see `blur1_min`.
    pub blur2_max: f32,

    // ---- scalar slot 4 (16 B): blur level 3 + display gamma ----
    /// Lower remap bound for blur level 3; see `blur1_min`.
    pub blur3_min: f32,
    /// Upper remap bound for blur level 3; see `blur1_min`.
    pub blur3_max: f32,
    /// MilkDrop's `f_gamma_adj`. Display-only multiplier (not a true gamma);
    /// applied in the comp pass after any user shader work.
    pub gamma_adj: f32,
    /// Reserved for a future scalar; keeps the slot count a multiple of 4.
    pub _reserved1: f32,

    // ---- scalar slot 5 (16 B): echo + stereo ----
    /// MilkDrop's `f_video_echo_zoom`. Comp pass samples the previous
    /// frame's display at `(uv-0.5)*echo_zoom + 0.5` and blends it with
    /// the current warp output. `0.0` disables the echo branch entirely
    /// (treated as `echo_alpha = 0`).
    pub echo_zoom: f32,
    /// MilkDrop's `f_video_echo_alpha`. Mix weight: `0.0` = no echo, full
    /// current frame; `1.0` = pure echo from the previous frame.
    pub echo_alpha: f32,
    /// MilkDrop's `n_video_echo_orientation`. 0 = no flip, 1 = flip X,
    /// 2 = flip Y, 3 = flip both. Stored as `f32` to avoid mixed
    /// scalar types in the uniform layout; the comp shader casts to
    /// an integer at use site.
    pub echo_orient: f32,
    /// MilkDrop's `b_red_blue_stereo`. `1.0` = anaglyph mode active,
    /// `0.0` = off. Wired through to `comp.wgsl`.
    pub red_blue_stereo: f32,

    // ---- vec4 slot 0: aspect (16 B) ----
    /// `(aspect_x, aspect_y, 1/aspect_x, 1/aspect_y)`. MD2 historically
    /// supplies all four to spare the shader an extra division.
    pub aspect: [f32; 4],

    // ---- vec4 slot 1: texsize_main (16 B) ----
    /// `(width, height, 1/width, 1/height)` of the warp/comp render
    /// target — the texture user shaders sample as `sampler_main`.
    pub texsize: [f32; 4],

    // ---- vec4 slot 2: per-preset random vector (16 B) ----
    /// Four `[0, 1)` floats sampled once at preset load. MD2's `rand_preset`.
    pub rand_preset: [f32; 4],

    // ---- vec4 slot 3: per-frame random vector (16 B) ----
    /// Four `[0, 1)` floats resampled every frame.
    pub rand_frame: [f32; 4],

    // ---- vec4 slots 4-7: roaming animations (4 × 16 B) ----
    /// Pre-computed `cos` of the four MD2 "slow roam" oscillators.
    pub slow_roam_cos: [f32; 4],
    /// Pre-computed `sin` of the same.
    pub slow_roam_sin: [f32; 4],
    /// Faster-cycling roam oscillators.
    pub roam_cos: [f32; 4],
    /// Presumably the `sin` of the same faster-cycling oscillators as
    /// `roam_cos`, mirroring the `slow_roam_cos` / `slow_roam_sin` pair.
    pub roam_sin: [f32; 4],

    // ---- vec4 slot 8: hue_shader tint (16 B) ----
    /// MD2's per-frame RGB tint (`.xyz`). Many comp shaders multiply
    /// `ret * hue_shader` for a slow rainbow rotation across frames;
    /// without this populated, the multiplication degrades to identity
    /// (`vec3(1, 1, 1)`) and the rotation disappears. The engine
    /// computes a time-driven tri-phase oscillator each frame; the
    /// `.w` lane is reserved for a future per-corner mix factor.
    pub hue_shader: [f32; 4],

    // ---- vec4 array: q channels (8 × 16 B = 128 B) ----
    /// `q[0]` packs `q1..q4`, `q[1]` packs `q5..q8`, etc. Filled from
    /// `MilkEngine::q_snapshot()`. Eight slots of four lanes cover the
    /// 32 channels counted by [`SHADER_Q_CHANNELS`].
    pub q: [[f32; 4]; 8],
}
149
150impl ShaderUniforms {
151    /// Size of the uniform buffer in bytes. Useful when allocating GPU
152    /// buffers without taking `size_of::<Self>()` at the call site.
153    pub const SIZE: usize = std::mem::size_of::<Self>();
154
155    /// Pack `q1..q32` into the eight-vec4 layout the WGSL struct expects.
156    pub fn set_q_channels(&mut self, q: &[f32; SHADER_Q_CHANNELS]) {
157        for (i, slot) in self.q.iter_mut().enumerate() {
158            slot[0] = q[i * 4];
159            slot[1] = q[i * 4 + 1];
160            slot[2] = q[i * 4 + 2];
161            slot[3] = q[i * 4 + 3];
162        }
163    }
164}
165
166impl Default for ShaderUniforms {
167    fn default() -> Self {
168        Self {
169            time: 0.0,
170            fps: 60.0,
171            frame: 0.0,
172            progress: 0.0,
173            bass: 0.0,
174            mid: 0.0,
175            treb: 0.0,
176            vol: 0.0,
177            bass_att: 0.0,
178            mid_att: 0.0,
179            treb_att: 0.0,
180            vol_att: 0.0,
181            blur1_min: 0.0,
182            blur1_max: 1.0,
183            blur2_min: 0.0,
184            blur2_max: 1.0,
185            blur3_min: 0.0,
186            blur3_max: 1.0,
187            gamma_adj: 2.0,
188            _reserved1: 0.0,
189            echo_zoom: 1.0,
190            echo_alpha: 0.0,
191            echo_orient: 0.0,
192            red_blue_stereo: 0.0,
193            aspect: [1.0, 1.0, 1.0, 1.0],
194            texsize: [1.0, 1.0, 1.0, 1.0],
195            rand_preset: [0.0; 4],
196            rand_frame: [0.0; 4],
197            slow_roam_cos: [0.0; 4],
198            slow_roam_sin: [0.0; 4],
199            roam_cos: [0.0; 4],
200            roam_sin: [0.0; 4],
201            hue_shader: [1.0, 1.0, 1.0, 1.0],
202            q: [[0.0; 4]; 8],
203        }
204    }
205}
206
/// WGSL `struct` declaration mirroring [`ShaderUniforms`] plus the standard
/// `@group(0) @binding(0)` uniform binding.
///
/// Field order/types are identical; this string is concatenated to the front
/// of every translated user shader.
///
/// The text declares the struct type `ShaderUniforms` and a single
/// `var<uniform>` named `uniforms`. It begins and ends with a newline, so
/// it can be placed directly before other WGSL source.
pub const SHADER_UNIFORMS_WGSL: &str = r#"
struct ShaderUniforms {
    // scalar slot 0
    time: f32,
    fps: f32,
    frame: f32,
    progress: f32,
    // scalar slot 1
    bass: f32,
    mid: f32,
    treb: f32,
    vol: f32,
    // scalar slot 2
    bass_att: f32,
    mid_att: f32,
    treb_att: f32,
    vol_att: f32,
    // scalar slot 3
    blur1_min: f32,
    blur1_max: f32,
    blur2_min: f32,
    blur2_max: f32,
    // scalar slot 4
    blur3_min: f32,
    blur3_max: f32,
    gamma_adj: f32,
    _reserved1: f32,
    // scalar slot 5 — echo + anaglyph stereo
    echo_zoom: f32,
    echo_alpha: f32,
    echo_orient: f32,
    red_blue_stereo: f32,
    // vec4 slots
    aspect: vec4<f32>,
    texsize: vec4<f32>,
    rand_preset: vec4<f32>,
    rand_frame: vec4<f32>,
    slow_roam_cos: vec4<f32>,
    slow_roam_sin: vec4<f32>,
    roam_cos: vec4<f32>,
    roam_sin: vec4<f32>,
    // Per-frame RGB tint (`.xyz`); `.w` reserved.
    hue_shader: vec4<f32>,
    // q1..q32 packed as 8 × vec4
    q: array<vec4<f32>, 8>,
}

@group(0) @binding(0)
var<uniform> uniforms: ShaderUniforms;
"#;
262
263/// Build a full prelude string by appending the user's shader body to the
264/// standard `ShaderUniforms` declaration. `body` is expected to already be
265/// translated to WGSL.
266pub fn with_prelude(body: &str) -> String {
267    let mut out = String::with_capacity(SHADER_UNIFORMS_WGSL.len() + body.len() + 1);
268    out.push_str(SHADER_UNIFORMS_WGSL);
269    out.push('\n');
270    out.push_str(body);
271    out
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// 6 scalar slots × 16 B + 9 vec4 fields × 16 B + 8 × vec4 q-array × 16 B
    /// = (6 + 9 + 8) × 16 = 368 bytes. Slot 5 holds the echo +
    /// red_blue_stereo block; vec4 slot 8 is `hue_shader`.
    const EXPECTED_SIZE: usize = (6 + 9 + 8) * 16;

    /// No-op vertex entry point appended after the prelude so naga has a
    /// complete module to parse and validate.
    const TRIVIAL_VERTEX_SHADER: &str =
        "@vertex fn vs() -> @builtin(position) vec4<f32> { return vec4<f32>(0.0); }";

    /// Parse the prelude (wrapped in the no-op shader) with naga's WGSL
    /// front end.
    fn parse_prelude() -> naga::Module {
        naga::front::wgsl::parse_str(&with_prelude(TRIVIAL_VERTEX_SHADER))
            .expect("WGSL prelude must parse with naga")
    }

    /// Run naga's validator over `module` with every flag and capability
    /// enabled, panicking on the first validation error.
    fn validate(module: &naga::Module) {
        let mut validator = naga::valid::Validator::new(
            naga::valid::ValidationFlags::all(),
            naga::valid::Capabilities::all(),
        );
        validator.validate(module).expect("module must validate");
    }

    /// Look up `ShaderUniforms` in the type arena and return its members
    /// together with the byte size (`span`) naga computed for it.
    fn wgsl_struct_layout(module: &naga::Module) -> (&[naga::StructMember], usize) {
        let ty = module
            .types
            .iter()
            .map(|(_, ty)| ty)
            .find(|ty| ty.name.as_deref() == Some("ShaderUniforms"))
            .expect("ShaderUniforms struct missing from prelude");
        match &ty.inner {
            naga::TypeInner::Struct { members, span } => (members.as_slice(), *span as usize),
            _ => panic!("ShaderUniforms is not a struct in WGSL"),
        }
    }

    /// Expand to an array of `(field name, byte offset)` pairs for the
    /// listed fields of a `ShaderUniforms` value, in the order given.
    /// Offsets are measured as address differences, so no `unsafe` and no
    /// recent-toolchain `offset_of!` is needed.
    macro_rules! rust_field_offsets {
        ($instance:ident: $($field:ident),+ $(,)?) => {{
            let base = std::ptr::addr_of!($instance) as usize;
            [$((
                stringify!($field),
                std::ptr::addr_of!($instance.$field) as usize - base,
            )),+]
        }};
    }

    #[test]
    fn rust_struct_matches_expected_layout() {
        assert_eq!(std::mem::size_of::<ShaderUniforms>(), EXPECTED_SIZE);
    }

    #[test]
    fn rust_struct_size_matches_wgsl_layout() {
        // Compile the prelude through naga, look up `ShaderUniforms` in the
        // type arena, and confirm naga's computed byte size matches the
        // Rust mirror's `size_of`. This catches drift that changes the
        // total size (a missing field, a swapped vec4/scalar, an alignment
        // surprise); same-size reordering is covered by
        // `layout_is_byte_for_byte_compatible`.
        let module = parse_prelude();
        validate(&module);

        let (_, wgsl_size) = wgsl_struct_layout(&module);
        assert_eq!(
            wgsl_size,
            std::mem::size_of::<ShaderUniforms>(),
            "Rust struct ({} B) and WGSL struct ({} B) disagree on size — \
             field order or padding has drifted",
            std::mem::size_of::<ShaderUniforms>(),
            wgsl_size,
        );
    }

    #[test]
    fn layout_is_byte_for_byte_compatible() {
        // Size equality alone cannot see two same-sized fields swapped
        // between the declarations (e.g. `bass`/`mid`), which would feed
        // shaders the wrong values without any error. Compare every
        // member's name and byte offset instead.
        let module = parse_prelude();
        validate(&module);
        let (members, _) = wgsl_struct_layout(&module);

        let u = ShaderUniforms::default();
        // Must list every field of `ShaderUniforms` in declaration order.
        let rust_fields = rust_field_offsets!(u:
            time, fps, frame, progress,
            bass, mid, treb, vol,
            bass_att, mid_att, treb_att, vol_att,
            blur1_min, blur1_max, blur2_min, blur2_max,
            blur3_min, blur3_max, gamma_adj, _reserved1,
            echo_zoom, echo_alpha, echo_orient, red_blue_stereo,
            aspect, texsize, rand_preset, rand_frame,
            slow_roam_cos, slow_roam_sin, roam_cos, roam_sin,
            hue_shader,
            q,
        );

        assert_eq!(
            members.len(),
            rust_fields.len(),
            "WGSL struct and the Rust field list differ in field count"
        );
        for (index, (member, &(name, offset))) in
            members.iter().zip(rust_fields.iter()).enumerate()
        {
            assert_eq!(
                member.name.as_deref(),
                Some(name),
                "field #{} is named differently in WGSL and Rust",
                index,
            );
            assert_eq!(
                member.offset as usize,
                offset,
                "field `{}` sits at a different byte offset in WGSL and Rust",
                name,
            );
        }
    }

    #[test]
    fn wgsl_declaration_parses() {
        // The prelude alone is just a struct + uniform binding — naga
        // accepts it as a valid module fragment when wrapped in a no-op
        // shader.
        let module = parse_prelude();
        // Sanity: the ShaderUniforms struct should be present in the
        // type arena.
        let has_struct = module.types.iter().any(|(_, ty)| {
            matches!(&ty.inner, naga::TypeInner::Struct { .. })
                && ty.name.as_deref() == Some("ShaderUniforms")
        });
        assert!(has_struct, "naga did not see ShaderUniforms in the prelude");
    }

    #[test]
    fn with_prelude_concatenates_in_order() {
        let body = "// user body";
        let combined = with_prelude(body);
        let prelude_end = combined.find(body).expect("body not found");
        assert!(
            combined[..prelude_end].contains("struct ShaderUniforms"),
            "prelude must come before the user body"
        );
    }

    #[test]
    fn set_q_channels_packs_correctly() {
        let mut u = ShaderUniforms::default();
        let mut q = [0.0f32; SHADER_Q_CHANNELS];
        for (i, slot) in q.iter_mut().enumerate() {
            *slot = i as f32 + 1.0; // q1=1, q2=2, ...
        }
        u.set_q_channels(&q);
        // q1..q4 land in u.q[0]
        assert_eq!(u.q[0], [1.0, 2.0, 3.0, 4.0]);
        // q29..q32 land in u.q[7]
        assert_eq!(u.q[7], [29.0, 30.0, 31.0, 32.0]);
    }

    #[test]
    fn default_is_pod_safe_zero_initialisable() {
        // Pod requires "all bit patterns valid"; the default is a sane
        // starting point but we also want the all-zeroes pattern (used by
        // `Zeroable`) to be a valid uniform — exercised here.
        let zeroed: ShaderUniforms = bytemuck::Zeroable::zeroed();
        assert_eq!(zeroed.time, 0.0);
        assert_eq!(zeroed.q, [[0.0; 4]; 8]);
    }
}