onedrop_codegen/
compiler.rs

1//! Shader compiler with naga validation
2
3use crate::error::{CodegenError, Result};
4use crate::prelude::SHADER_UNIFORMS_WGSL;
5use onedrop_hlsl::{MAX_USER_TEXTURE_SLOTS, TextureBindingPlan};
6use std::collections::HashMap;
7use std::fmt::Write as _;
8use std::sync::{Arc, Mutex};
9
/// Number of user-texture bindings the wrapped comp/warp shaders declare.
///
/// Re-export so the renderer can spell the same constant in its bind group
/// builder. Single source of truth: the wrapper declares `MAX_USER_TEXTURE_SLOTS`
/// bindings; the comp pipeline allocates that many entries.
///
/// NOTE(review): the private `USER_COMP_BINDINGS` WGSL block hard-codes eight
/// `sampler_user_N_texture` declarations (bindings 15–22) followed by
/// `sampler_prev_main_texture` at binding 23; if this value ever changes, that
/// WGSL text must be updated in lockstep.
pub const USER_TEXTURE_SLOTS: usize = MAX_USER_TEXTURE_SLOTS;
14
/// First `@binding` index used by the user-texture array.
///
/// The 15 bindings below it (`0..=14`) are claimed by built-ins: the uniform
/// block (binding 0, declared by the WGSL prelude), the main texture and
/// `sampler_main` (1–2), the three blur-pyramid textures (3–5), the five
/// procedural noise textures (6–10) and the four MD2 sampler variants
/// `sampler_fw` / `_fc` / `_pw` / `_pc` (11–14).
///
/// The user slots occupy
/// `USER_TEXTURE_FIRST_BINDING .. USER_TEXTURE_FIRST_BINDING + USER_TEXTURE_SLOTS`
/// and are declared one by one in this module's (private) `USER_COMP_BINDINGS`
/// WGSL block, which must be kept in sync with this value.
pub const USER_TEXTURE_FIRST_BINDING: u32 = 15;
19
/// A WGSL shader that naga has parsed and validated.
///
/// Cloning shares the module and validation info (both behind `Arc`) but
/// copies the `source` string.
#[derive(Clone)]
pub struct CompiledShader {
    /// WGSL text that was compiled. For the `compile_user_*` entry points
    /// this is the fully wrapped module, not the user's translated body.
    pub source: String,

    /// Validated naga module
    pub module: Arc<naga::Module>,

    /// Validation output produced by naga for `module`
    pub info: Arc<naga::valid::ModuleInfo>,
}
29    /// Module info for validation
30    pub info: Arc<naga::valid::ModuleInfo>,
31}
32
33/// Shader compiler with caching
34pub struct ShaderCompiler {
35    cache: Arc<Mutex<HashMap<String, CompiledShader>>>,
36    validator: naga::valid::Validator,
37}
38
39impl ShaderCompiler {
40    pub fn new() -> Self {
41        Self {
42            cache: Arc::new(Mutex::new(HashMap::new())),
43            validator: naga::valid::Validator::new(
44                naga::valid::ValidationFlags::all(),
45                naga::valid::Capabilities::all(),
46            ),
47        }
48    }
49
50    /// Compile and validate a WGSL shader
51    pub fn compile(&mut self, source: &str) -> Result<CompiledShader> {
52        // Check cache first
53        {
54            let cache = self.cache.lock().unwrap();
55            if let Some(compiled) = cache.get(source) {
56                log::debug!("Shader cache hit");
57                return Ok(compiled.clone());
58            }
59        }
60
61        log::debug!("Compiling shader ({} bytes)", source.len());
62
63        // Parse WGSL
64        let module = naga::front::wgsl::parse_str(source)
65            .map_err(|e| CodegenError::Compilation(format!("WGSL parse error: {:?}", e)))?;
66
67        // Validate
68        let info = self
69            .validator
70            .validate(&module)
71            .map_err(|e| CodegenError::Compilation(format!("Validation error: {:?}", e)))?;
72
73        let compiled = CompiledShader {
74            source: source.to_string(),
75            module: Arc::new(module),
76            info: Arc::new(info),
77        };
78
79        // Cache it
80        {
81            let mut cache = self.cache.lock().unwrap();
82            cache.insert(source.to_string(), compiled.clone());
83        }
84
85        log::debug!("Shader compiled and cached successfully");
86
87        Ok(compiled)
88    }
89
90    /// Get cache statistics
91    pub fn cache_stats(&self) -> CacheStats {
92        let cache = self.cache.lock().unwrap();
93        CacheStats {
94            size: cache.len(),
95            total_source_bytes: cache.values().map(|s| s.source.len()).sum(),
96        }
97    }
98
99    /// Clear the cache
100    pub fn clear_cache(&mut self) {
101        let mut cache = self.cache.lock().unwrap();
102        cache.clear();
103        log::debug!("Shader cache cleared");
104    }
105
106    /// Translate a MilkDrop 2 comp shader from HLSL, wrap it into a complete
107    /// WGSL fragment module (prelude + texture bindings + standard MD2
108    /// fragment inputs + entry points), and validate via naga.
109    ///
110    /// Returns the wrapped WGSL on success — the renderer's `CompPipeline`
111    /// can feed it to `wgpu::Device::create_shader_module` directly. On
112    /// failure, the error string captures whichever stage tripped (regex
113    /// translation, WGSL parse, validation).
114    ///
115    /// Compile rates depend on the AST-based translator in [`onedrop_hlsl`];
116    /// presets the translator can't fully translate fall back to a gamma-only
117    /// default shader at the caller level so the render path stays alive.
118    pub fn compile_user_comp_shader(&mut self, hlsl: &str) -> Result<CompiledShader> {
119        let plan = TextureBindingPlan::empty();
120        self.compile_user_comp_shader_with_plan(hlsl, &plan)
121    }
122
123    /// Variant of [`Self::compile_user_comp_shader`] that threads a
124    /// [`TextureBindingPlan`] through the translator and wrapper. The
125    /// renderer scans the HLSL, resolves each `sampler sampler_X;` against
126    /// its texture pool, and hands the result over here so the emitted
127    /// WGSL routes `tex2D(sampler_clouds, …)` onto the right user-texture
128    /// binding and exposes the matching `texsize_clouds` constant inside
129    /// `fs_main`.
130    pub fn compile_user_comp_shader_with_plan(
131        &mut self,
132        hlsl: &str,
133        plan: &TextureBindingPlan,
134    ) -> Result<CompiledShader> {
135        let translated = onedrop_hlsl::translate_shader_with_plan(hlsl, plan)
136            .map_err(|e| CodegenError::Compilation(format!("HLSL→WGSL translate: {e}")))?;
137        let wrapped = wrap_user_comp_shader_with_plan(&translated, plan);
138        self.compile(&wrapped)
139    }
140
141    /// Compile a user-authored MD2 warp shader (the per-pixel feedback
142    /// stage that runs before the composite pass) through the same
143    /// HLSL→WGSL translator the comp pipeline uses, then validate via
144    /// naga.
145    ///
146    /// The wrapper at [`wrap_user_warp_shader_with_plan`] mirrors the
147    /// comp wrapper's bind-group layout — same `ShaderUniforms`, same
148    /// texture binding numbers, same `GetPixel`/`GetBlur*` helpers — so
149    /// the renderer can reuse the comp pipeline's texture-plan
150    /// machinery. The only structural differences:
151    ///
152    /// 1. The vertex shader consumes the warp-mesh `WarpVertex`
153    ///    (`pos_clip` + `uv_warp`) instead of generating a fullscreen
154    ///    triangle. `uv_warp` is the per-vertex MD2-formula-warped UV
155    ///    the CPU writes each frame; the rasteriser interpolates it
156    ///    across each triangle.
157    /// 2. `ret` is seeded from `sampler_main_texture` at the warped
158    ///    UV. The renderer is expected to bind `prev_texture` to the
159    ///    `sampler_main_texture` slot so MD2's "previous frame at the
160    ///    warped UV" semantics fall out for free.
161    pub fn compile_user_warp_shader_with_plan(
162        &mut self,
163        hlsl: &str,
164        plan: &TextureBindingPlan,
165    ) -> Result<CompiledShader> {
166        let translated = onedrop_hlsl::translate_shader_with_plan(hlsl, plan)
167            .map_err(|e| CodegenError::Compilation(format!("HLSL→WGSL translate: {e}")))?;
168        let wrapped = wrap_user_warp_shader_with_plan(&translated, plan);
169        self.compile(&wrapped)
170    }
171
172    /// Empty-plan convenience for [`Self::compile_user_warp_shader_with_plan`].
173    pub fn compile_user_warp_shader(&mut self, hlsl: &str) -> Result<CompiledShader> {
174        let plan = TextureBindingPlan::empty();
175        self.compile_user_warp_shader_with_plan(hlsl, &plan)
176    }
177}
178
179/// Wrap a translated WGSL fragment body into a complete shader module
180/// (prelude + texture bindings + entry points).
181///
182/// The user body is pasted into `fs_main` after standard MD2 inputs (`uv`,
183/// `uv_orig`, `rad`, `ang`, `ret`) have been declared as locals.
184/// The fragment returns `vec4<f32>(ret * uniforms.gamma_adj, 1.0)`, so a
185/// well-behaved MD2 shader that assigns to `ret` produces a correct result.
186///
187/// Made `pub` so tests and the CLI can inspect the wrapped output before
188/// shipping it to naga.
189pub fn wrap_user_comp_shader(translated_body: &str) -> String {
190    wrap_user_comp_shader_with_plan(translated_body, &TextureBindingPlan::empty())
191}
192
193/// Wrap a translated WGSL fragment body for the **warp pass** into a
194/// complete shader module. Mirror of [`wrap_user_comp_shader_with_plan`]
195/// — same uniforms, same texture bindings, same MD2 private-state
196/// declarations — but with a vertex shader that consumes the warp-mesh
197/// `WarpVertex` (`pos_clip` + `uv_warp`) and a fragment that seeds `uv`
198/// from `uv_warp` and `ret` from the previous-frame texture at that UV.
199///
200/// The renderer is expected to bind `prev_texture` to the
201/// `sampler_main_texture` slot when running the warp pipeline so the
202/// MD2 convention "sample the previous frame at the warped UV" lands
203/// naturally. Helper functions (`GetPixel` / `GetBlur*`) therefore
204/// also read from `prev_texture`'s view.
205pub fn wrap_user_warp_shader_with_plan(translated_body: &str, plan: &TextureBindingPlan) -> String {
206    let (lifted_fns, body) = match translated_body.split_once(onedrop_hlsl::LIFTED_FN_SENTINEL) {
207        Some((lifted, body)) => (lifted, body.trim_start_matches('\n')),
208        None => ("", translated_body),
209    };
210
211    let mut s = String::with_capacity(
212        SHADER_UNIFORMS_WGSL.len() + USER_COMP_HELPERS.len() + translated_body.len() + 1536,
213    );
214    s.push_str(SHADER_UNIFORMS_WGSL);
215    s.push_str(USER_COMP_BINDINGS);
216    s.push_str(MD2_MATH_CONSTANTS);
217    s.push_str(MD2_NOISE_TEXSIZE_CONSTANTS);
218    s.push_str(MD2_PRIVATE_STATE_DECLS);
219    s.push_str(USER_COMP_HELPERS);
220    if !lifted_fns.trim().is_empty() {
221        s.push_str("\n// ---- lifted user-defined functions ----\n");
222        s.push_str(lifted_fns);
223        s.push('\n');
224    }
225    s.push_str(USER_WARP_VERTEX);
226    s.push_str(USER_WARP_FRAGMENT_PREFIX);
227    append_user_texsize_constants(&mut s, plan);
228    s.push_str(body);
229    s.push_str(USER_WARP_FRAGMENT_SUFFIX);
230    s
231}
232
233/// Same as [`wrap_user_comp_shader`], but emits per-preset `texsize_<NAME>`
234/// constants from the supplied [`TextureBindingPlan`]. The user-texture
235/// bindings themselves are always declared (so the bind-group layout the
236/// comp pipeline owns stays stable); only the WGSL identifiers the user
237/// body can reference (`texsize_clouds`, …) change per preset.
238pub fn wrap_user_comp_shader_with_plan(translated_body: &str, plan: &TextureBindingPlan) -> String {
239    // `translate_shader` may emit module-scope user functions in a prefix
240    // region separated by `LIFTED_FN_SENTINEL`. Split here so the lifted
241    // text lands BEFORE `fs_main` (where it belongs) and the body lands
242    // INSIDE `fs_main`. When no sentinel is present, the whole translated
243    // text is treated as body and the module-scope region stays empty.
244    let (lifted_fns, body) = match translated_body.split_once(onedrop_hlsl::LIFTED_FN_SENTINEL) {
245        Some((lifted, body)) => (lifted, body.trim_start_matches('\n')),
246        None => ("", translated_body),
247    };
248
249    let mut s = String::with_capacity(
250        SHADER_UNIFORMS_WGSL.len() + USER_COMP_HELPERS.len() + translated_body.len() + 1536,
251    );
252    s.push_str(SHADER_UNIFORMS_WGSL);
253    s.push_str(USER_COMP_BINDINGS);
254    s.push_str(MD2_MATH_CONSTANTS);
255    s.push_str(MD2_NOISE_TEXSIZE_CONSTANTS);
256    s.push_str(MD2_PRIVATE_STATE_DECLS);
257    s.push_str(USER_COMP_HELPERS);
258    if !lifted_fns.trim().is_empty() {
259        s.push_str("\n// ---- lifted user-defined functions ----\n");
260        s.push_str(lifted_fns);
261        s.push('\n');
262    }
263    s.push_str(USER_COMP_VERTEX);
264    s.push_str(USER_COMP_FRAGMENT_PREFIX);
265    append_user_texsize_constants(&mut s, plan);
266    s.push_str(body);
267    s.push_str(USER_COMP_FRAGMENT_SUFFIX);
268    s
269}
270
271/// Emit a `let texsize_<NAME>: vec4<f32> = vec4<f32>(w, h, 1/w, 1/h);` line
272/// for each filled slot in the plan. Preset code reads these directly:
273/// `tex2D(sampler_clouds, uv + texsize_clouds.zw * dt)` is a common pattern
274/// for "step one texel along the texture dimension".
275fn append_user_texsize_constants(out: &mut String, plan: &TextureBindingPlan) {
276    if plan.slot_count() == 0 {
277        return;
278    }
279    out.push_str("    // ---- user texture pack: per-preset texsize constants ----\n");
280    for slot in plan.slots() {
281        let Some(name) = &slot.pool_name else {
282            continue;
283        };
284        // Pool names are already lowercase-canonical. Skip any that contain
285        // characters that aren't valid WGSL identifiers — the codegen
286        // wrapper can't reference them otherwise. Real preset texture
287        // names in the wild are all ASCII alnum + underscore.
288        if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
289            continue;
290        }
291        let [w, h, rw, rh] = slot.texsize;
292        let _ = writeln!(
293            out,
294            "    let texsize_{name}: vec4<f32> = vec4<f32>({w:.1}, {h:.1}, {rw}, {rh});"
295        );
296    }
297}
298
299const USER_COMP_BINDINGS: &str = r#"
300@group(0) @binding(1)
301var sampler_main_texture: texture_2d<f32>;
302
303@group(0) @binding(2)
304var sampler_main: sampler;
305
306// Blur pyramid. Bindings 3/4/5 are progressively softer Gaussian-
307// blurred copies of the warp output, populated each frame by
308// `BlurPipeline`. They share `sampler_main` as their sampler — both
309// the comp pass and these blur textures want linear filtering.
310@group(0) @binding(3)
311var sampler_blur1_texture: texture_2d<f32>;
312
313@group(0) @binding(4)
314var sampler_blur2_texture: texture_2d<f32>;
315
316@group(0) @binding(5)
317var sampler_blur3_texture: texture_2d<f32>;
318
319// Procedural noise pack. Five fixed-resolution textures generated at
320// engine init from a deterministic seed. The translator routes
321// `tex2D(sampler_noise_lq|mq|hq, …)` and `tex3D(sampler_noisevol_lq|hq, …)`
322// onto these bindings; user shaders that reference unknown noise sampler
323// names still fall back to `sampler_main`.
324@group(0) @binding(6)
325var sampler_noise_lq_texture: texture_2d<f32>;
326
327@group(0) @binding(7)
328var sampler_noise_mq_texture: texture_2d<f32>;
329
330@group(0) @binding(8)
331var sampler_noise_hq_texture: texture_2d<f32>;
332
333@group(0) @binding(9)
334var sampler_noisevol_lq_texture: texture_3d<f32>;
335
336@group(0) @binding(10)
337var sampler_noisevol_hq_texture: texture_3d<f32>;
338
339// MD2 sampler variants: filter (filtered=f / point=p) × address (wrap=w /
340// clamp=c). The translator rewrites `tex2D(sampler_fw_main, uv)` →
341// `textureSample(sampler_main_texture, sampler_fw, uv)`, and similarly
342// for `_fc_` / `_pw_` / `_pc_`. Authored presets pick the variant per
343// sample site; the runtime always has all four available.
344@group(0) @binding(11)
345var sampler_fw: sampler;
346
347@group(0) @binding(12)
348var sampler_fc: sampler;
349
350@group(0) @binding(13)
351var sampler_pw: sampler;
352
353@group(0) @binding(14)
354var sampler_pc: sampler;
355
356// ---- User texture pack ----
357// Disk-loaded textures referenced by `sampler sampler_<NAME>;` declarations
358// in the preset's HLSL. The translator routes `tex2D(sampler_<NAME>, uv)`
359// onto one of these slots based on the per-preset
360// `onedrop_hlsl::TextureBindingPlan`. Unfilled slots get a 1×1 white
361// fallback at bind time so the WGSL parses and the comp pass draws cleanly
362// even when the texture pool is empty.
363@group(0) @binding(15)
364var sampler_user_0_texture: texture_2d<f32>;
365
366@group(0) @binding(16)
367var sampler_user_1_texture: texture_2d<f32>;
368
369@group(0) @binding(17)
370var sampler_user_2_texture: texture_2d<f32>;
371
372@group(0) @binding(18)
373var sampler_user_3_texture: texture_2d<f32>;
374
375@group(0) @binding(19)
376var sampler_user_4_texture: texture_2d<f32>;
377
378@group(0) @binding(20)
379var sampler_user_5_texture: texture_2d<f32>;
380
381@group(0) @binding(21)
382var sampler_user_6_texture: texture_2d<f32>;
383
384@group(0) @binding(22)
385var sampler_user_7_texture: texture_2d<f32>;
386
387// Previous-frame display for the echo blend. Bound by
388// the comp pipeline at binding 23 (immediately after the user-texture
389// slots). User shaders can sample this directly if they want to roll
390// their own echo; the default pipeline applies the standard MD2 blend
391// in `shaders/comp.wgsl`.
392@group(0) @binding(23)
393var sampler_prev_main_texture: texture_2d<f32>;
394"#;
395
396/// MD2 math constants (mirrors projectM's `milkdrop-shaders.h`).
397///
398/// User shaders reference these by short bare names. `M_PI_2` (11 hits) and
399/// `M_INV_PI_2` (8 hits) are the most common in the 200-preset survey;
400/// `M_PI` and `M_INV_PI` are added for completeness — they cost nothing.
401const MD2_MATH_CONSTANTS: &str = r#"
402const M_PI: f32 = 3.14159265358979;
403const M_PI_2: f32 = 1.57079632679489;
404const M_INV_PI: f32 = 0.31830988618379;
405const M_INV_PI_2: f32 = 0.15915494309189;
406"#;
407
408/// Procedural noise pack `texsize_*` vec4 constants. Module-scope (was
409/// previously emitted inside `fs_main`) so lifted user functions that
410/// reference `texsize_noise_lq.zw` for a per-texel offset compile cleanly.
411/// The values are fixed at engine init (see `onedrop-renderer::noise`), so
412/// promoting them from `let` to `const` is a free win.
413const MD2_NOISE_TEXSIZE_CONSTANTS: &str = r#"
414const texsize_noise_lq: vec4<f32>      = vec4<f32>(256.0, 256.0, 0.00390625, 0.00390625);
415const texsize_noise_lq_lite: vec4<f32> = vec4<f32>( 32.0,  32.0, 0.03125,    0.03125);
416const texsize_noise_mq: vec4<f32>      = vec4<f32>( 64.0,  64.0, 0.015625,   0.015625);
417const texsize_noise_hq: vec4<f32>      = vec4<f32>( 32.0,  32.0, 0.03125,    0.03125);
418const texsize_noisevol_lq: vec4<f32>   = vec4<f32>( 32.0,  32.0, 0.03125,    0.03125);
419const texsize_noisevol_hq: vec4<f32>   = vec4<f32>(  8.0,   8.0, 0.125,      0.125);
420"#;
421
422/// MD2 per-invocation private state hoisted to module scope.
423///
424/// Every binding a user shader can reference by its bare MD2 name
425/// (`q1`..`q32`, `texsize`, `time`, `bass`, `slow_roam_cos`, `uv`, `rad`,
426/// `ang`, `ret`, …) is declared here as `var<private>`. That makes them
427/// visible to lifted user-defined functions (which live at module scope
428/// alongside these declarations) — the previous design pinned them as
429/// `let` bindings inside `fs_main`, so a lifted `float3 lavcol(float t)
430/// { return pow(.5, 2*t*slow_roam_cos); }` would fail naga parse with
431/// "no definition in scope for identifier: `slow_roam_cos`".
432///
433/// WGSL `var<private>` is per-invocation, so each fragment thread has its
434/// own copy — assigning these at the top of `fs_main` from `uniforms.*` /
435/// per-pixel values is exactly what the original `let` site did, just one
436/// scope higher.
437const MD2_PRIVATE_STATE_DECLS: &str = r#"
438// ---- MD2 wrapper private state (per-invocation, module scope) ----
439// fs_main + every lifted user function reads/writes these. Values are
440// assigned at the top of fs_main; reading them before assignment is
441// undefined behavior, but every code path in our wrapper writes first.
442var<private> uv: vec2<f32>;
443var<private> uv_orig: vec2<f32>;
444var<private> rad: f32;
445var<private> ang: f32;
446var<private> ret: vec3<f32>;
447// `_md2_color` is a wrapper scratch slot — kept under a `_md2_`-prefixed
448// name to avoid colliding with the very common user-declared `color`.
449var<private> _md2_color: vec3<f32>;
450
451var<private> texsize: vec4<f32>;
452var<private> aspect: vec4<f32>;
453var<private> time: f32;
454var<private> fps: f32;
455var<private> frame: f32;
456var<private> progress: f32;
457var<private> bass: f32;
458var<private> mid: f32;
459var<private> treb: f32;
460var<private> vol: f32;
461var<private> bass_att: f32;
462var<private> mid_att: f32;
463var<private> treb_att: f32;
464var<private> vol_att: f32;
465var<private> rand_preset: vec4<f32>;
466var<private> rand_frame: vec4<f32>;
467var<private> slow_roam_cos: vec4<f32>;
468var<private> slow_roam_sin: vec4<f32>;
469var<private> roam_cos: vec4<f32>;
470var<private> roam_sin: vec4<f32>;
471var<private> blur1_min: f32;
472var<private> blur1_max: f32;
473var<private> blur2_min: f32;
474var<private> blur2_max: f32;
475var<private> blur3_min: f32;
476var<private> blur3_max: f32;
477var<private> hue_shader: vec3<f32>;
478// MD1 legacy alias for `texsize` (some old presets read `g_fTexSize.zw`).
479var<private> g_fTexSize: vec4<f32>;
480
481var<private> q1: f32;
482var<private> q2: f32;
483var<private> q3: f32;
484var<private> q4: f32;
485var<private> q5: f32;
486var<private> q6: f32;
487var<private> q7: f32;
488var<private> q8: f32;
489var<private> q9: f32;
490var<private> q10: f32;
491var<private> q11: f32;
492var<private> q12: f32;
493var<private> q13: f32;
494var<private> q14: f32;
495var<private> q15: f32;
496var<private> q16: f32;
497var<private> q17: f32;
498var<private> q18: f32;
499var<private> q19: f32;
500var<private> q20: f32;
501var<private> q21: f32;
502var<private> q22: f32;
503var<private> q23: f32;
504var<private> q24: f32;
505var<private> q25: f32;
506var<private> q26: f32;
507var<private> q27: f32;
508var<private> q28: f32;
509var<private> q29: f32;
510var<private> q30: f32;
511var<private> q31: f32;
512var<private> q32: f32;
513"#;
514
515/// MD2 helper functions every comp shader expects to find at module scope.
516///
517/// `GetPixel(uv)` returns the current main texture sample as `vec3<f32>` —
518/// the canonical MD2 convention is to discard alpha at the comp stage.
519///
520/// `GetBlur1/2/3(uv)` sample the cumulative Gaussian-blur pyramid produced
521/// by [`BlurPipeline`] each frame. The vast majority of in-the-wild comp
522/// shaders reference at least one `GetBlur*` — without a real pyramid they
523/// would degrade to `GetPixel` and lose the soft-halo look the preset
524/// depends on.
525const USER_COMP_HELPERS: &str = r#"
526fn GetPixel(uv: vec2<f32>) -> vec3<f32> {
527    return textureSample(sampler_main_texture, sampler_main, uv).rgb;
528}
529
530// MD2 alias — many in-the-wild comp shaders call GetMain(uv) instead of
531// GetPixel(uv). Same semantics: sample the main render target.
532fn GetMain(uv: vec2<f32>) -> vec3<f32> {
533    return textureSample(sampler_main_texture, sampler_main, uv).rgb;
534}
535
536// Blur-pyramid samplers. MD2 stores the three Gaussian-blurred copies of
537// the warp output in a compressed range `[blurN_min, blurN_max]` (the
538// engine writes these uniforms each frame); the canonical `GetBlurN(uv)`
539// helper re-expands that range to `[0, 1]` before handing it to the user
540// shader. Without this remap, presets that depend on the implicit clamp
541// (`color * GetBlur1(uv)`) saturate near full-white or full-black even on
542// quiet content, throwing the visual balance MD2 is tuned for.
543fn GetBlur1(uv: vec2<f32>) -> vec3<f32> {
544    let sample = textureSample(sampler_blur1_texture, sampler_main, uv).rgb;
545    return mix(vec3<f32>(blur1_min), vec3<f32>(blur1_max), sample);
546}
547
548fn GetBlur2(uv: vec2<f32>) -> vec3<f32> {
549    let sample = textureSample(sampler_blur2_texture, sampler_main, uv).rgb;
550    return mix(vec3<f32>(blur2_min), vec3<f32>(blur2_max), sample);
551}
552
553fn GetBlur3(uv: vec2<f32>) -> vec3<f32> {
554    let sample = textureSample(sampler_blur3_texture, sampler_main, uv).rgb;
555    return mix(vec3<f32>(blur3_min), vec3<f32>(blur3_max), sample);
556}
557
558/// MD2 luminance helper. 78 / 168 in-the-wild comp shaders call `lum(c)`
559/// to weight the perceptual brightness of a sample. The MD2 source uses a
560/// Rec. 601-ish weight `(0.32, 0.49, 0.29)` (note: not exactly normalised);
561/// we mirror the weights so visual output stays close to the reference.
562fn lum(c: vec3<f32>) -> f32 {
563    return dot(c, vec3<f32>(0.32, 0.49, 0.29));
564}
565"#;
566
567const USER_COMP_VERTEX: &str = r#"
568struct VertexOutput {
569    @builtin(position) position: vec4<f32>,
570    @location(0) uv: vec2<f32>,
571}
572
573@vertex
574fn vs_main(@builtin(vertex_index) idx: u32) -> VertexOutput {
575    let x = f32(idx & 1u) * 4.0 - 1.0;
576    let y = f32((idx >> 1u) & 1u) * 4.0 - 1.0;
577    var out: VertexOutput;
578    out.position = vec4<f32>(x, y, 0.0, 1.0);
579    out.uv = vec2<f32>((x + 1.0) * 0.5, 1.0 - (y + 1.0) * 0.5);
580    return out;
581}
582"#;
583
584const USER_COMP_FRAGMENT_PREFIX: &str = r#"
585@fragment
586fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
587    // All MD2 bare-name bindings live at module scope (`var<private>` in
588    // `MD2_PRIVATE_STATE_DECLS`) so lifted user-defined functions can read
589    // them. Initialise from the per-invocation inputs and the uniform
590    // struct here — WGSL `var<private>` has no static initialiser, so the
591    // first thing every fragment thread does is seed its private state.
592    uv = in.uv;
593    uv_orig = in.uv;
594    rad = length(uv - vec2<f32>(0.5)) * 1.4142135;
595    ang = atan2(uv.y - 0.5, uv.x - 0.5);
596    ret = textureSample(sampler_main_texture, sampler_main, uv).rgb;
597    _md2_color = ret;
598
599    texsize = uniforms.texsize;
600    aspect = uniforms.aspect;
601    time = uniforms.time;
602    fps = uniforms.fps;
603    frame = uniforms.frame;
604    progress = uniforms.progress;
605    bass = uniforms.bass;
606    mid = uniforms.mid;
607    treb = uniforms.treb;
608    vol = uniforms.vol;
609    bass_att = uniforms.bass_att;
610    mid_att = uniforms.mid_att;
611    treb_att = uniforms.treb_att;
612    vol_att = uniforms.vol_att;
613    rand_preset = uniforms.rand_preset;
614    rand_frame = uniforms.rand_frame;
615    slow_roam_cos = uniforms.slow_roam_cos;
616    slow_roam_sin = uniforms.slow_roam_sin;
617    roam_cos = uniforms.roam_cos;
618    roam_sin = uniforms.roam_sin;
619    blur1_min = uniforms.blur1_min;
620    blur1_max = uniforms.blur1_max;
621    blur2_min = uniforms.blur2_min;
622    blur2_max = uniforms.blur2_max;
623    blur3_min = uniforms.blur3_min;
624    blur3_max = uniforms.blur3_max;
625    // `hue_shader` is the per-frame RGB tint MD2 cycles through over time.
626    // The engine writes a time-driven tri-phase oscillator into
627    // `uniforms.hue_shader.xyz` each frame; here we just splice it in so
628    // user shaders that multiply `ret * hue_shader` get the slow rainbow
629    // rotation MD2 is tuned for.
630    hue_shader = uniforms.hue_shader.xyz;
631    g_fTexSize = uniforms.texsize;
632
633    q1 = uniforms.q[0].x;
634    q2 = uniforms.q[0].y;
635    q3 = uniforms.q[0].z;
636    q4 = uniforms.q[0].w;
637    q5 = uniforms.q[1].x;
638    q6 = uniforms.q[1].y;
639    q7 = uniforms.q[1].z;
640    q8 = uniforms.q[1].w;
641    q9 = uniforms.q[2].x;
642    q10 = uniforms.q[2].y;
643    q11 = uniforms.q[2].z;
644    q12 = uniforms.q[2].w;
645    q13 = uniforms.q[3].x;
646    q14 = uniforms.q[3].y;
647    q15 = uniforms.q[3].z;
648    q16 = uniforms.q[3].w;
649    q17 = uniforms.q[4].x;
650    q18 = uniforms.q[4].y;
651    q19 = uniforms.q[4].z;
652    q20 = uniforms.q[4].w;
653    q21 = uniforms.q[5].x;
654    q22 = uniforms.q[5].y;
655    q23 = uniforms.q[5].z;
656    q24 = uniforms.q[5].w;
657    q25 = uniforms.q[6].x;
658    q26 = uniforms.q[6].y;
659    q27 = uniforms.q[6].z;
660    q28 = uniforms.q[6].w;
661    q29 = uniforms.q[7].x;
662    q30 = uniforms.q[7].y;
663    q31 = uniforms.q[7].z;
664    q32 = uniforms.q[7].w;
665
666    // ---- begin translated user body ----
667"#;
668
669const USER_COMP_FRAGMENT_SUFFIX: &str = r#"
670    // ---- end translated user body ----
671
672    return vec4<f32>(ret * uniforms.gamma_adj, 1.0);
673}
674"#;
675
676/// Warp-pass vertex shader. Consumes the same `WarpVertex` layout the
677/// renderer uses for the default warp pipeline (`pos_clip: vec2<f32>`
678/// at `@location(0)` plus `uv_warp: vec2<f32>` at `@location(1)`) so
679/// switching between the default and user pipelines doesn't change
680/// vertex buffer plumbing on the renderer side.
681///
682/// Outputs `uv` (the per-pixel rasterised warp UV) and `uv_screen` (the
683/// screen-space `[0, 1]` UV recovered from `pos_clip`, kept for parity
684/// with `warp.wgsl`'s built-in shaders).
685const USER_WARP_VERTEX: &str = r#"
686struct VertexOutput {
687    @builtin(position) position: vec4<f32>,
688    @location(0) uv: vec2<f32>,
689    @location(1) uv_screen: vec2<f32>,
690}
691
692@vertex
693fn vs_main(
694    @location(0) pos_clip: vec2<f32>,
695    @location(1) uv_warp: vec2<f32>,
696) -> VertexOutput {
697    var out: VertexOutput;
698    out.position = vec4<f32>(pos_clip, 0.0, 1.0);
699    out.uv = uv_warp;
700    out.uv_screen = pos_clip * 0.5 + vec2<f32>(0.5);
701    return out;
702}
703"#;
704
705/// Fragment-shader header for user-authored warp shaders. Same shape
706/// as `USER_COMP_FRAGMENT_PREFIX` (seed every MD2 `var<private>` binding
707/// from `uniforms` and per-invocation inputs), but `uv`/`uv_orig` come
708/// from the rasterised warp UV — not from a fullscreen-triangle screen
709/// UV — and `ret` is seeded from the previous-frame texture at that
710/// warped UV. The renderer binds `prev_texture` to `sampler_main_texture`
711/// so `GetPixel(uv)` (defined in the shared helpers) does the right thing.
712const USER_WARP_FRAGMENT_PREFIX: &str = r#"
713@fragment
714fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
715    uv = in.uv;
716    // `uv_orig` is the *pre-warp* screen-space UV. Warp HLSL authors
717    // sometimes reference it to compare the displaced sample against
718    // the pixel's own coordinate (radial fades, vignette terms).
719    uv_orig = in.uv_screen;
720    rad = length(in.uv_screen - vec2<f32>(0.5)) * 1.4142135;
721    ang = atan2(in.uv_screen.y - 0.5, in.uv_screen.x - 0.5);
722    ret = textureSample(sampler_main_texture, sampler_main, uv).rgb;
723    _md2_color = ret;
724
725    texsize = uniforms.texsize;
726    aspect = uniforms.aspect;
727    time = uniforms.time;
728    fps = uniforms.fps;
729    frame = uniforms.frame;
730    progress = uniforms.progress;
731    bass = uniforms.bass;
732    mid = uniforms.mid;
733    treb = uniforms.treb;
734    vol = uniforms.vol;
735    bass_att = uniforms.bass_att;
736    mid_att = uniforms.mid_att;
737    treb_att = uniforms.treb_att;
738    vol_att = uniforms.vol_att;
739    rand_preset = uniforms.rand_preset;
740    rand_frame = uniforms.rand_frame;
741    slow_roam_cos = uniforms.slow_roam_cos;
742    slow_roam_sin = uniforms.slow_roam_sin;
743    roam_cos = uniforms.roam_cos;
744    roam_sin = uniforms.roam_sin;
745    blur1_min = uniforms.blur1_min;
746    blur1_max = uniforms.blur1_max;
747    blur2_min = uniforms.blur2_min;
748    blur2_max = uniforms.blur2_max;
749    blur3_min = uniforms.blur3_min;
750    blur3_max = uniforms.blur3_max;
751    hue_shader = vec3<f32>(1.0, 1.0, 1.0);
752    g_fTexSize = uniforms.texsize;
753
754    q1 = uniforms.q[0].x;
755    q2 = uniforms.q[0].y;
756    q3 = uniforms.q[0].z;
757    q4 = uniforms.q[0].w;
758    q5 = uniforms.q[1].x;
759    q6 = uniforms.q[1].y;
760    q7 = uniforms.q[1].z;
761    q8 = uniforms.q[1].w;
762    q9 = uniforms.q[2].x;
763    q10 = uniforms.q[2].y;
764    q11 = uniforms.q[2].z;
765    q12 = uniforms.q[2].w;
766    q13 = uniforms.q[3].x;
767    q14 = uniforms.q[3].y;
768    q15 = uniforms.q[3].z;
769    q16 = uniforms.q[3].w;
770    q17 = uniforms.q[4].x;
771    q18 = uniforms.q[4].y;
772    q19 = uniforms.q[4].z;
773    q20 = uniforms.q[4].w;
774    q21 = uniforms.q[5].x;
775    q22 = uniforms.q[5].y;
776    q23 = uniforms.q[5].z;
777    q24 = uniforms.q[5].w;
778    q25 = uniforms.q[6].x;
779    q26 = uniforms.q[6].y;
780    q27 = uniforms.q[6].z;
781    q28 = uniforms.q[6].w;
782    q29 = uniforms.q[7].x;
783    q30 = uniforms.q[7].y;
784    q31 = uniforms.q[7].z;
785    q32 = uniforms.q[7].w;
786
787    // ---- begin translated user body ----
788"#;
789
/// Fragment-shader trailer for the warp pass, pasted directly after the
/// translated user body. It closes the `// ---- begin translated user body ----`
/// section opened by the warp prefix and then closes the fragment function.
///
/// It mirrors the comp suffix except for the final return: the warp pass
/// emits `vec4(ret, 1.0)` with no `gamma_adj` factor, because gamma is a
/// display-side concern applied later in the comp pass.
const USER_WARP_FRAGMENT_SUFFIX: &str = r#"
    // ---- end translated user body ----

    return vec4<f32>(ret, 1.0);
}
"#;
799
800impl Default for ShaderCompiler {
801    fn default() -> Self {
802        Self::new()
803    }
804}
805
/// Point-in-time snapshot of the compiler's shader cache, as returned by
/// `ShaderCompiler::cache_stats`.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of compiled shaders currently held in the cache.
    pub size: usize,
    /// Presumably the summed byte length of the cached WGSL sources
    /// (NOTE(review): confirm against `cache_stats`).
    pub total_source_bytes: usize,
}
812
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal fragment-only WGSL module shared by the cache tests.
    const RED_FRAGMENT: &str = r#"
@fragment
fn fs_main() -> @location(0) vec4<f32> {
    return vec4<f32>(1.0, 0.0, 0.0, 1.0);
}
"#;

    #[test]
    fn test_compile_simple_shader() {
        let mut compiler = ShaderCompiler::new();

        let source = r#"
@vertex
fn vs_main(@builtin(vertex_index) in_vertex_index: u32) -> @builtin(position) vec4<f32> {
    return vec4<f32>(0.0, 0.0, 0.0, 1.0);
}

@fragment
fn fs_main() -> @location(0) vec4<f32> {
    return vec4<f32>(1.0, 0.0, 0.0, 1.0);
}
"#;

        let result = compiler.compile(source);
        // Surface the naga error on failure instead of a bare `false`.
        assert!(result.is_ok(), "got: {:?}", result.err());
    }

    #[test]
    fn test_cache_hit() {
        let mut compiler = ShaderCompiler::new();

        // First compile validates the source and populates the cache.
        let first = compiler
            .compile(RED_FRAGMENT)
            .expect("first compile must validate");

        // Second compile of the identical source should be served from the
        // cache: same result, and no second cache entry.
        let second = compiler
            .compile(RED_FRAGMENT)
            .expect("cached compile must succeed");

        assert_eq!(first.source, second.source);
        assert_eq!(compiler.cache_stats().size, 1);
    }

    #[test]
    fn test_invalid_shader() {
        let mut compiler = ShaderCompiler::new();

        let source = "invalid shader code";

        let result = compiler.compile(source);
        assert!(result.is_err());
    }

    #[test]
    fn empty_user_body_compiles() {
        // The empty body case validates the wrapper itself: the prelude,
        // bindings, vertex shader, and the trivial `ret = sample(main)`
        // fragment must all parse and validate without any user code.
        let mut compiler = ShaderCompiler::new();
        let result = compiler.compile_user_comp_shader("");
        assert!(
            result.is_ok(),
            "empty wrapper should compile, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn trivial_user_body_compiles() {
        // A minimal MD2-style body: scale the sampled color in `ret`.
        // It needs no HLSL type rewriting, so this isolates the
        // translator -> wrapper -> naga plumbing for a plain assignment.
        let mut compiler = ShaderCompiler::new();
        let body = "ret = ret * 0.5;";
        let result = compiler.compile_user_comp_shader(body);
        assert!(result.is_ok(), "got: {:?}", result.err());
    }

    #[test]
    fn invalid_user_body_returns_error() {
        // Garbage in the user body must surface a Compilation error rather
        // than panic.
        let mut compiler = ShaderCompiler::new();
        let body = "this is not even close to HLSL @@@";
        let result = compiler.compile_user_comp_shader(body);
        assert!(result.is_err());
    }

    #[test]
    fn empty_warp_body_compiles() {
        // The warp wrapper is structurally the comp wrapper minus the
        // fullscreen-triangle vs_main, plus a `WarpVertex`-consuming
        // entry point. An empty body still has to parse + validate as a
        // self-contained module (vertex + fragment, same bindings as
        // comp, `ret` defaulted to the previous-frame sample).
        let mut compiler = ShaderCompiler::new();
        let result = compiler.compile_user_warp_shader("");
        assert!(
            result.is_ok(),
            "empty warp wrapper should compile, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn trivial_warp_body_compiles() {
        // The canonical MD2 warp pattern from corpus preset 427.milk:
        //     ret = tex2D(sampler_main, uv).xyz * 0.85;
        //     ret -= 0.022;
        // Exercises every layer (translator, type-aware passes, wrapper)
        // and proves the warp wrap accepts the same HLSL idioms the
        // comp wrap already handles.
        let mut compiler = ShaderCompiler::new();
        let body = "ret = tex2D(sampler_main, uv).xyz * 0.85;\nret -= 0.022;";
        let result = compiler.compile_user_warp_shader(body);
        assert!(result.is_ok(), "got: {:?}", result.err());
    }

    #[test]
    fn warp_wrapper_consumes_warpvertex_layout() {
        // The renderer ships `pos_clip` at @location(0) and `uv_warp` at
        // @location(1); the wrapper must consume the same layout so the
        // default and user warp pipelines can share a vertex buffer.
        let wrapped = wrap_user_warp_shader_with_plan("", &TextureBindingPlan::empty());
        assert!(
            wrapped.contains("@location(0) pos_clip: vec2<f32>"),
            "warp vertex shader must read pos_clip at @location(0)"
        );
        assert!(
            wrapped.contains("@location(1) uv_warp: vec2<f32>"),
            "warp vertex shader must read uv_warp at @location(1)"
        );
        // Differs from the comp wrapper: no `gamma_adj` factor on the
        // final return (gamma is applied later in the comp pass).
        assert!(
            !wrapped.contains("ret * uniforms.gamma_adj"),
            "warp output must not apply gamma_adj"
        );
        assert!(
            wrapped.contains("return vec4<f32>(ret, 1.0)"),
            "warp output must return ret unchanged"
        );
    }

    #[test]
    fn wrapper_includes_prelude_and_bindings() {
        let wrapped = wrap_user_comp_shader("// body");
        assert!(wrapped.contains("struct ShaderUniforms"), "no prelude");
        assert!(
            wrapped.contains("sampler_main_texture"),
            "no main texture binding"
        );
        assert!(wrapped.contains("@vertex"), "no vertex stage");
        assert!(wrapped.contains("@fragment"), "no fragment stage");
        assert!(wrapped.contains("// body"), "user body not pasted");
        // The wrapper always declares the user-texture slots,
        // even with an empty plan — the bind-group layout the renderer
        // owns is fixed-size, so the WGSL must reference all of them or
        // wgpu rejects the pipeline.
        for slot in 0..USER_TEXTURE_SLOTS {
            let needle = format!("sampler_user_{slot}_texture");
            assert!(
                wrapped.contains(&needle),
                "user texture binding {slot} missing from wrapper"
            );
        }
    }

    #[test]
    fn wrapper_with_plan_emits_texsize_constants() {
        let mut plan = TextureBindingPlan::empty();
        plan.add_slot(
            Some("clouds".to_string()),
            [256.0, 128.0, 1.0 / 256.0, 1.0 / 128.0],
            &[("sampler_clouds".to_string(), "sampler_fw")],
        )
        .unwrap();
        plan.add_slot(
            Some("worms".to_string()),
            [64.0, 64.0, 1.0 / 64.0, 1.0 / 64.0],
            &[("sampler_worms".to_string(), "sampler_fw")],
        )
        .unwrap();

        let wrapped = wrap_user_comp_shader_with_plan("// body", &plan);
        assert!(
            wrapped.contains("let texsize_clouds: vec4<f32> = vec4<f32>(256.0, 128.0,"),
            "got: {wrapped}"
        );
        assert!(
            wrapped.contains("let texsize_worms: vec4<f32> = vec4<f32>(64.0, 64.0,"),
            "got: {wrapped}"
        );
    }

    #[test]
    fn user_shader_referencing_plan_texture_compiles() {
        // End-to-end: a real MD2-style body that samples `sampler_clouds`
        // must translate, validate, and use the plan-routed binding.
        let mut compiler = ShaderCompiler::new();
        let mut plan = TextureBindingPlan::empty();
        plan.add_slot(
            Some("clouds".to_string()),
            [256.0, 256.0, 1.0 / 256.0, 1.0 / 256.0],
            &[("sampler_clouds".to_string(), "sampler_fw")],
        )
        .unwrap();
        let hlsl = "ret = tex2D(sampler_clouds, uv).xyz;";
        let compiled = compiler
            .compile_user_comp_shader_with_plan(hlsl, &plan)
            .expect("plan-driven user shader must validate");
        // The compiled source carries the user-binding routing, not the
        // sampler_main fallback.
        assert!(compiled.source.contains("sampler_user_0_texture"));
        assert!(!compiled.source.contains("/*was: sampler_clouds*/"));
    }

    #[test]
    fn test_clear_cache() {
        let mut compiler = ShaderCompiler::new();

        compiler
            .compile(RED_FRAGMENT)
            .expect("fragment shader must validate");
        assert_eq!(compiler.cache_stats().size, 1);

        compiler.clear_cache();
        assert_eq!(compiler.cache_stats().size, 0);
    }
}