onedrop_hlsl/
types.rs

1//! Type-aware post-translator passes for the HLSL→WGSL pipeline.
2//!
3//! After the regex rewrites in `lib.rs` produce something WGSL-shaped, the
4//! remaining failures cluster around HLSL semantics WGSL doesn't honour:
5//!
6//! 1. **Implicit scalar→vector broadcasts** in calls like
7//!    `clamp(<vec3>, 0, 1)`. HLSL broadcasts the scalars `0` and `1` to
8//!    `vec3(0)` and `vec3(1)` automatically; WGSL refuses with
9//!    `inconsistent type passed as argument #2 to clamp`.
10//! 2. **Implicit scalar←vector truncation** in declarations like
11//!    `float lum = GetPixel(uv) * c.x + …;`. HLSL takes the first
12//!    component of the vec result; WGSL refuses with `the type of lum is
13//!    expected to be f32; but got vec3<f32>`.
14//!
15//! This module fixes both with two passes that share a small symbol
16//! table built from the translated source's `var NAME: TYPE` and
17//! `let NAME: TYPE` declarations. Type inference for arbitrary
18//! expressions is intentionally heuristic: it returns a confident answer
19//! for the dominant MD2 patterns (named locals, numeric literals, vec/
20//! mat constructors, calls to a small set of known helpers) and
21//! `Unknown` otherwise. `Unknown` short-circuits both passes — the
22//! translator only injects fixes it's sure about.
23
24use std::collections::HashMap;
25
/// The WGSL types the post-translator passes reason about for MD2 user
/// shaders. Every other type collapses to [`WgslType::Unknown`], which both
/// passes treat as "leave this expression alone".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgslType {
    /// `f32`.
    F32,
    /// `i32` (declarations typed `u32` are folded into this variant too).
    I32,
    /// `vec2<f32>`.
    Vec2F,
    /// `vec3<f32>`.
    Vec3F,
    /// `vec4<f32>`.
    Vec4F,
    /// `mat2x2<f32>`.
    Mat2F,
    /// `mat3x3<f32>`.
    Mat3F,
    /// `mat4x4<f32>`.
    Mat4F,
    /// `bool`.
    Bool,
    /// Not one of the types above, or inference could not decide.
    Unknown,
}
41
42impl WgslType {
43    /// `true` for the vector types we know how to broadcast to.
44    pub fn is_vec(self) -> bool {
45        matches!(self, Self::Vec2F | Self::Vec3F | Self::Vec4F)
46    }
47
48    /// `true` for scalar types — broadcast targets (the value being
49    /// wrapped, not the dominant arg type).
50    pub fn is_scalar(self) -> bool {
51        matches!(self, Self::F32 | Self::I32 | Self::Bool)
52    }
53
54    /// WGSL spelling, for emission.
55    pub fn wgsl_name(self) -> &'static str {
56        match self {
57            Self::F32 => "f32",
58            Self::I32 => "i32",
59            Self::Vec2F => "vec2<f32>",
60            Self::Vec3F => "vec3<f32>",
61            Self::Vec4F => "vec4<f32>",
62            Self::Mat2F => "mat2x2<f32>",
63            Self::Mat3F => "mat3x3<f32>",
64            Self::Mat4F => "mat4x4<f32>",
65            Self::Bool => "bool",
66            Self::Unknown => "/* unknown */",
67        }
68    }
69
70    /// Parse the type fragment that appears between `: ` and `=`/`;` in
71    /// a `var/let` declaration. Whitespace-trimmed; case-sensitive.
72    fn from_decl_str(s: &str) -> Self {
73        match s.trim() {
74            "f32" => Self::F32,
75            "i32" => Self::I32,
76            "u32" => Self::I32, // close enough for our purposes
77            "vec2<f32>" => Self::Vec2F,
78            "vec3<f32>" => Self::Vec3F,
79            "vec4<f32>" => Self::Vec4F,
80            "mat2x2<f32>" => Self::Mat2F,
81            "mat3x3<f32>" => Self::Mat3F,
82            "mat4x4<f32>" => Self::Mat4F,
83            "bool" => Self::Bool,
84            _ => Self::Unknown,
85        }
86    }
87}
88
89/// Maps locally-declared identifier → its WGSL type. Pre-seeded with the
90/// uniforms and helper-function bindings the codegen wrapper exposes
91/// inside `fs_main` (so that `texsize`, `q1..q32`, `aspect`, etc. resolve
92/// to known types).
93pub struct SymbolTable {
94    pub locals: HashMap<String, WgslType>,
95}
96
97impl SymbolTable {
98    /// Build a symbol table by scanning `var NAME: TYPE` / `let NAME: TYPE`
99    /// declarations in the translated source. Pre-seeded with the wrapper
100    /// preamble's bindings so user code can be analysed in isolation.
101    pub fn from_source(src: &str) -> Self {
102        let mut locals = HashMap::new();
103
104        // Wrapper preamble (kept in sync with `wrap_user_comp_shader` in
105        // `onedrop-codegen`). These are `let` aliases inside fs_main, so
106        // the user shader sees them unprefixed.
107        for (n, t) in WRAPPER_PRELUDE_LOCALS {
108            locals.insert((*n).to_string(), *t);
109        }
110
111        // Walk the source and pick up every `var X: T` / `let X: T`.
112        let bytes = src.as_bytes();
113        let mut i = 0;
114        while i < bytes.len() {
115            // Skip line comments and block comments so `var` etc. inside
116            // them don't pollute the symbol table.
117            if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'/' {
118                while i < bytes.len() && bytes[i] != b'\n' {
119                    i += 1;
120                }
121                continue;
122            }
123            if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' {
124                i += 2;
125                while i + 1 < bytes.len() && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
126                    i += 1;
127                }
128                i += 2;
129                continue;
130            }
131
132            let kw_len = match keyword_at(bytes, i, &["var", "let"]) {
133                Some(n) => n,
134                None => {
135                    i += 1;
136                    continue;
137                }
138            };
139            let mut j = i + kw_len;
140            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
141                j += 1;
142            }
143            let name_start = j;
144            while j < bytes.len() && (bytes[j].is_ascii_alphanumeric() || bytes[j] == b'_') {
145                j += 1;
146            }
147            if j == name_start {
148                i = j + 1;
149                continue;
150            }
151            let name = &src[name_start..j];
152            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
153                j += 1;
154            }
155            if j >= bytes.len() || bytes[j] != b':' {
156                // No type annotation — skip.
157                i = j + 1;
158                continue;
159            }
160            j += 1;
161            // Read the type fragment until `=` or `;` at depth 0.
162            let ty_start = j;
163            let mut depth_angle = 0i32;
164            while j < bytes.len() {
165                match bytes[j] {
166                    b'<' => depth_angle += 1,
167                    b'>' => depth_angle -= 1,
168                    b'=' | b';' if depth_angle == 0 => break,
169                    _ => {}
170                }
171                j += 1;
172            }
173            let ty = WgslType::from_decl_str(&src[ty_start..j]);
174            if !matches!(ty, WgslType::Unknown) {
175                // The scanner is flat — it ignores function scopes. When
176                // a name is declared with mutually-incompatible types in
177                // separate functions (`var tmp: vec2<f32>` in GetDist
178                // vs `var tmp: f32` in MinDist on the whoah/martin
179                // shared bodies), the downstream coercion passes pick
180                // whichever the scanner happened to see last and
181                // mis-type every assignment to that name in the other
182                // function. Demote to `Unknown` only when the conflict
183                // mixes scalar with vec — the assignment passes treat
184                // Unknown as "leave alone" and naga gets to validate
185                // each scope on its own. Same-shape conflicts (vec3 vs
186                // vec4) keep last-wins so the existing swizzle-assign
187                // rewrite still has a vec width to work with — those
188                // are common when a global `float3 col;` shadows a
189                // local `float4 col;` and the user's `col.rgb *= …`
190                // only needs the LHS to *be* a vec to lower cleanly.
191                use std::collections::hash_map::Entry;
192                match locals.entry(name.to_string()) {
193                    Entry::Vacant(e) => {
194                        e.insert(ty);
195                    }
196                    Entry::Occupied(mut e) => {
197                        let prev = *e.get();
198                        if prev != ty
199                            && (prev.is_scalar() != ty.is_scalar() || prev.is_vec() != ty.is_vec())
200                        {
201                            e.insert(WgslType::Unknown);
202                        } else {
203                            // Same kind (both vec / both scalar) but
204                            // different exact type — keep latest.
205                            e.insert(ty);
206                        }
207                    }
208                }
209            }
210            i = j;
211        }
212
213        Self { locals }
214    }
215
216    pub fn lookup(&self, name: &str) -> Option<WgslType> {
217        self.locals.get(name).copied()
218    }
219
220    /// Best-effort type inference for an expression text. Returns a
221    /// confident type for the dominant patterns, [`WgslType::Unknown`]
222    /// otherwise — callers must treat `Unknown` as "do not modify".
223    ///
224    /// Recognised:
225    /// - A numeric literal (with or without `f`/`i`/`u` suffix) → `F32`.
226    /// - A bare identifier in the symbol table → that type.
227    /// - A constructor call `vec3<f32>(...)` etc. → that type.
228    /// - A call to a known helper (`GetPixel`, `GetBlur1..3`, `lum`,
229    ///   `textureSample`, etc.) → its return type.
230    /// - A swizzle access `<expr>.<comp>` where the prefix is a vec
231    ///   type — narrowed to `F32` for `.x/.y/.z/.w/.r/.g/.b/.a` of length 1,
232    ///   `Vec2F` for 2, `Vec3F` for 3, `Vec4F` for 4.
233    /// - A binary op (`+`, `-`, `*`, `/`) where one side is vec → the
234    ///   widest vec type. Falls back to `F32` when both sides are scalar.
235    pub fn infer_expr_type(&self, expr: &str) -> WgslType {
236        // Strip C-style comments first — `/*was: foo*/` markers from the
237        // tex2D rewrite contain `/` and `*`, which the binop splitter
238        // would otherwise mistake for top-level division/multiplication
239        // and slice the expression in two. We only strip for inference;
240        // the visible output keeps the comments for debuggability.
241        let stripped = strip_comments(expr);
242        let expr = strip_outer_parens(stripped.trim());
243
244        if expr.is_empty() {
245            return WgslType::Unknown;
246        }
247
248        // Numeric literal.
249        if is_numeric_literal(expr) {
250            return WgslType::F32;
251        }
252
253        // Bare identifier.
254        if is_identifier(expr) {
255            if let Some(t) = self.lookup(expr) {
256                return t;
257            }
258            // A stray unknown identifier is just unknown — could be a
259            // user variable we missed.
260            return WgslType::Unknown;
261        }
262
263        // Constructor calls and known helper calls.
264        if let Some((head, args)) = split_call(expr) {
265            if let Some(t) = constructor_type(head) {
266                return t;
267            }
268            if let Some(t) = known_call_return_type(head) {
269                return t;
270            }
271            // Polymorphic builtins (`clamp`, `min`, `max`, …) take the
272            // *element type* of their args. We pick the smallest vec
273            // among the args (matching the broadcast pass's "narrow to
274            // smallest" rule); if no vec is present, fall back to F32.
275            if POLY_BUILTINS.contains(&head) {
276                let mut smallest: Option<WgslType> = None;
277                for a in split_top_level_commas(args) {
278                    let t = self.infer_expr_type(a.trim());
279                    if t.is_vec() {
280                        smallest = Some(match smallest {
281                            None => t,
282                            Some(s) => narrower(s, t),
283                        });
284                    }
285                }
286                return smallest.unwrap_or(WgslType::F32);
287            }
288        }
289
290        // Binary op at the top level FIRST — `a + b.x` is `a + (b.x)`,
291        // so a top-level `+` outranks the swizzle's `.`. Split on
292        // `+ - * /` at depth 0 and pick the widest type among operands.
293        if let Some(operands) = split_binop_operands(expr) {
294            let mut widest = WgslType::Unknown;
295            for op in &operands {
296                let t = self.infer_expr_type(op);
297                widest = widen(widest, t);
298            }
299            if !matches!(widest, WgslType::Unknown) {
300                return widest;
301            }
302        }
303
304        // Swizzle: split on the *last* `.` at depth 0. Only reached for
305        // single-term expressions like `c.xyz` — chained binops were
306        // handled above.
307        if let Some((prefix, comp)) = split_last_swizzle(expr)
308            && is_swizzle_components(comp)
309        {
310            let prefix_ty = self.infer_expr_type(prefix);
311            if prefix_ty.is_vec() {
312                return swizzle_target_type(comp.len());
313            }
314        }
315
316        // Unary prefix (`-x`, `+x`, `!x`) — same type as operand. Placed
317        // AFTER the top-level binop split so `a - b` still parses as a
318        // subtraction (operands `a` and `b`); only single-term primaries
319        // like `-r` or `-(rad-.5)` fall through to here. Without this,
320        // `clamp(v, -r, r)` saw arg #2 as Unknown and the broadcast pass
321        // wrapped only arg #3 — surfacing as `inconsistent type passed
322        // as argument #2 to clamp` (the single largest residual cluster
323        // on the post-f83a9cc 2000-sample, ≈24 presets).
324        if matches!(expr.as_bytes().first(), Some(b'-' | b'+' | b'!')) {
325            return self.infer_expr_type(&expr[1..]);
326        }
327
328        WgslType::Unknown
329    }
330}
331
/// Wrapper preamble locals exposed to user shaders. Mirrors the `let`
/// aliases in `onedrop-codegen::wrap_user_comp_shader::USER_COMP_FRAGMENT_PREFIX`.
///
/// `SymbolTable::from_source` inserts every entry here before scanning the
/// shader. These names therefore resolve even when the user code never
/// declares them. A user declaration of the same name is merged with the
/// entry under `from_source`'s redeclaration rules.
///
/// NOTE(review): this list is kept in sync by hand. Nothing in this file
/// checks it against the codegen prefix, so re-verify whenever the wrapper
/// changes.
const WRAPPER_PRELUDE_LOCALS: &[(&str, WgslType)] = &[
    // vec2 / scalar / vec3 aliases.
    ("uv", WgslType::Vec2F),
    ("uv_orig", WgslType::Vec2F),
    ("rad", WgslType::F32),
    ("ang", WgslType::F32),
    ("ret", WgslType::Vec3F),
    ("color", WgslType::Vec3F),
    // vec4 aliases.
    ("texsize", WgslType::Vec4F),
    ("aspect", WgslType::Vec4F),
    // Scalar f32 aliases.
    ("time", WgslType::F32),
    ("fps", WgslType::F32),
    ("frame", WgslType::F32),
    ("progress", WgslType::F32),
    ("bass", WgslType::F32),
    ("mid", WgslType::F32),
    ("treb", WgslType::F32),
    ("vol", WgslType::F32),
    ("bass_att", WgslType::F32),
    ("mid_att", WgslType::F32),
    ("treb_att", WgslType::F32),
    ("vol_att", WgslType::F32),
    // vec4 aliases.
    ("rand_preset", WgslType::Vec4F),
    ("rand_frame", WgslType::Vec4F),
    ("slow_roam_cos", WgslType::Vec4F),
    ("slow_roam_sin", WgslType::Vec4F),
    ("roam_cos", WgslType::Vec4F),
    ("roam_sin", WgslType::Vec4F),
    // Scalar blur range bounds.
    ("blur1_min", WgslType::F32),
    ("blur1_max", WgslType::F32),
    ("blur2_min", WgslType::F32),
    ("blur2_max", WgslType::F32),
    ("blur3_min", WgslType::F32),
    ("blur3_max", WgslType::F32),
    ("hue_shader", WgslType::Vec3F),
    // Texture-size vec4 aliases.
    ("g_fTexSize", WgslType::Vec4F),
    ("texsize_noise_lq", WgslType::Vec4F),
    ("texsize_noise_lq_lite", WgslType::Vec4F),
    ("texsize_noise_mq", WgslType::Vec4F),
    ("texsize_noise_hq", WgslType::Vec4F),
    ("texsize_noisevol_lq", WgslType::Vec4F),
    ("texsize_noisevol_hq", WgslType::Vec4F),
    // q1..q32 and the math constants are scalar f32.
    ("q1", WgslType::F32),
    ("q2", WgslType::F32),
    ("q3", WgslType::F32),
    ("q4", WgslType::F32),
    ("q5", WgslType::F32),
    ("q6", WgslType::F32),
    ("q7", WgslType::F32),
    ("q8", WgslType::F32),
    ("q9", WgslType::F32),
    ("q10", WgslType::F32),
    ("q11", WgslType::F32),
    ("q12", WgslType::F32),
    ("q13", WgslType::F32),
    ("q14", WgslType::F32),
    ("q15", WgslType::F32),
    ("q16", WgslType::F32),
    ("q17", WgslType::F32),
    ("q18", WgslType::F32),
    ("q19", WgslType::F32),
    ("q20", WgslType::F32),
    ("q21", WgslType::F32),
    ("q22", WgslType::F32),
    ("q23", WgslType::F32),
    ("q24", WgslType::F32),
    ("q25", WgslType::F32),
    ("q26", WgslType::F32),
    ("q27", WgslType::F32),
    ("q28", WgslType::F32),
    ("q29", WgslType::F32),
    ("q30", WgslType::F32),
    ("q31", WgslType::F32),
    ("q32", WgslType::F32),
    ("M_PI", WgslType::F32),
    ("M_PI_2", WgslType::F32),
    ("M_INV_PI", WgslType::F32),
    ("M_INV_PI_2", WgslType::F32),
];
413
/// If one of `kws` appears at byte `i` as a whole word, return its length.
///
/// A match must not be glued to identifier characters on either side, so
/// `avar` and `variable` do not match `var`. Keywords are tried in order and
/// the first hit wins.
fn keyword_at(bytes: &[u8], i: usize, kws: &[&str]) -> Option<usize> {
    let is_ident_byte = |c: u8| c.is_ascii_alphanumeric() || c == b'_';
    if i > 0 && is_ident_byte(bytes[i - 1]) {
        return None;
    }
    kws.iter()
        .map(|kw| kw.as_bytes())
        .find(|&kw| {
            bytes[i..].starts_with(kw)
                && !bytes.get(i + kw.len()).is_some_and(|&c| is_ident_byte(c))
        })
        .map(|kw| kw.len())
}
433
/// Replace `/* ... */` and `// ...` comments with spaces so the resulting
/// string has the same length and column positions as the original. This is
/// crucial because [`SymbolTable::infer_expr_type`] doesn't remap offsets
/// when reasoning about an arg's text.
///
/// Newlines inside block comments are kept so line numbers stay aligned.
/// Line comments stop before their newline. The search for a block
/// comment's `*/` begins after the opening `/*`, so `/*/` does not close
/// itself. An unterminated block comment is blanked up to the end of input.
fn strip_comments(src: &str) -> String {
    let bytes = src.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let rest = &bytes[i..];
        let comment_len = if rest.starts_with(b"//") {
            rest.iter().position(|&b| b == b'\n').unwrap_or(rest.len())
        } else if rest.starts_with(b"/*") {
            rest[2..]
                .windows(2)
                .position(|w| w == b"*/")
                .map_or(rest.len(), |p| p + 4)
        } else {
            out.push(bytes[i]);
            i += 1;
            continue;
        };
        out.extend(
            rest[..comment_len]
                .iter()
                .map(|&b| if b == b'\n' { b'\n' } else { b' ' }),
        );
        i += comment_len;
    }
    // Every comment span starts at an ASCII `/` and ends before an ASCII
    // `\n`, after an ASCII `/`, or at end of input. So whole code points are
    // either copied or blanked, and the output is still valid UTF-8.
    String::from_utf8(out).expect("comment blanking preserves UTF-8")
}
472
/// Peel off redundant wrapping parentheses, trimming whitespace at each
/// step. The expression is unwrapped only while its first `(` is closed by
/// its final `)`, so `((a + b))` becomes `a + b` but `(a) + (b)` is kept.
fn strip_outer_parens(expr: &str) -> &str {
    let mut current = expr.trim();
    loop {
        let Some(inner) = current
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'))
        else {
            return current;
        };
        let mut depth = 0i32;
        let first_balance = current.bytes().position(|b| {
            match b {
                b'(' => depth += 1,
                b')' => depth -= 1,
                _ => {}
            }
            depth == 0
        });
        if first_balance != Some(current.len() - 1) {
            return current;
        }
        current = inner.trim();
    }
}
504
/// Loose check for a numeric literal such as `1`, `-1.0`, `.5f` or `2e-3`.
///
/// Trailing `f`/`i`/`u`/`h` suffix characters are trimmed greedily and one
/// leading `-` is allowed. What remains must consist only of digits, `.`,
/// `e`/`E` and `+`/`-`, and must contain at least one digit. This is
/// deliberately permissive; for example, `1-2` is accepted.
fn is_numeric_literal(s: &str) -> bool {
    let body = s.trim().trim_end_matches(['f', 'i', 'u', 'h']);
    let body = body.strip_prefix('-').unwrap_or(body);
    let allowed = |c: char| c.is_ascii_digit() || matches!(c, '.' | 'e' | 'E' | '+' | '-');
    !body.is_empty() && body.chars().all(allowed) && body.chars().any(|c| c.is_ascii_digit())
}
520
/// Whether `s` is a plain ASCII identifier: a letter or `_` followed by any
/// number of letters, digits or `_`.
fn is_identifier(s: &str) -> bool {
    match s.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head.is_ascii_alphabetic() || head == b'_')
                && tail.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'_')
        }
        None => false,
    }
}
529
/// Split `<head>(<args>)` into `(head, args)`.
///
/// Returns `None` when the text isn't a single call spanning the whole
/// expression. Examples: no `(`, a leading `(`, trailing text such as `.x`,
/// or two calls joined by an operator. The head is the trimmed text before
/// the first `(`. It is not validated here; callers compare it against
/// exact names.
fn split_call(expr: &str) -> Option<(&str, &str)> {
    let open = expr.find('(')?;
    if open == 0 || !expr.ends_with(')') {
        return None;
    }
    let mut depth = 0i32;
    let close = expr.bytes().enumerate().skip(open).find_map(|(idx, b)| {
        match b {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return Some(idx);
                }
            }
            _ => {}
        }
        None
    })?;
    if close + 1 == expr.len() {
        Some((expr[..open].trim(), &expr[open + 1..close]))
    } else {
        None
    }
}
562
563fn constructor_type(head: &str) -> Option<WgslType> {
564    Some(match head {
565        "vec2<f32>" => WgslType::Vec2F,
566        "vec3<f32>" => WgslType::Vec3F,
567        "vec4<f32>" => WgslType::Vec4F,
568        "mat2x2<f32>" => WgslType::Mat2F,
569        "mat3x3<f32>" => WgslType::Mat3F,
570        "mat4x4<f32>" => WgslType::Mat4F,
571        "f32" => WgslType::F32,
572        "i32" => WgslType::I32,
573        _ => return None,
574    })
575}
576
577fn known_call_return_type(head: &str) -> Option<WgslType> {
578    Some(match head {
579        // MD2 wrapper helpers (in `onedrop-codegen::USER_COMP_HELPERS`).
580        "GetPixel" | "GetBlur1" | "GetBlur2" | "GetBlur3" => WgslType::Vec3F,
581        "lum" => WgslType::F32,
582        // WGSL builtins that return a known type regardless of args.
583        "length" | "distance" | "dot" => WgslType::F32,
584        "cross" => WgslType::Vec3F,
585        "textureSample" => WgslType::Vec4F,
586        // Same-type-as-arg builtins — best heuristic is "preserve vec3"
587        // because the dominant MD2 body manipulates vec3 colour. A
588        // conservative `Unknown` would block the broadcast pass; pick
589        // Vec3F since it's the dominant case and a wrong guess only
590        // means we fail to inject a needed broadcast (graceful: same
591        // behaviour as before this pass).
592        // (Intentionally not enumerated — return Unknown so the pass
593        // falls back on its own per-arg analysis.)
594        _ => return None,
595    })
596}
597
598/// Return `(prefix, components)` where `<expr>.<components>` is the swizzle.
599/// Splits on the last `.` at depth 0.
600fn split_last_swizzle(expr: &str) -> Option<(&str, &str)> {
601    let bytes = expr.as_bytes();
602    let mut depth_paren = 0i32;
603    let mut depth_angle = 0i32;
604    let mut last_dot = None;
605    for (i, &b) in bytes.iter().enumerate() {
606        match b {
607            b'(' => depth_paren += 1,
608            b')' => depth_paren -= 1,
609            b'<' => depth_angle += 1,
610            b'>' => depth_angle -= 1,
611            b'.' if depth_paren == 0 && depth_angle == 0 => last_dot = Some(i),
612            _ => {}
613        }
614    }
615    let dot = last_dot?;
616    // Don't split numeric literals.
617    let pre = expr[..dot].trim();
618    if pre.is_empty() || is_numeric_literal(pre) {
619        return None;
620    }
621    Some((pre, &expr[dot + 1..]))
622}
623
/// Whether `s` looks like a swizzle suffix: 1 to 4 characters, each one of
/// `xyzw` or `rgba`. Mixing the two sets is not rejected here.
fn is_swizzle_components(s: &str) -> bool {
    (1..=4).contains(&s.len()) && s.bytes().all(|b| b"xyzwrgba".contains(&b))
}
630
631/// Component count for vec types, 0 for everything else. Used by the
632/// truncation pass to size the trailing `.xy`/`.xyz` swizzle when
633/// narrowing a wider vec into a smaller one.
634pub(crate) fn vec_size(t: WgslType) -> usize {
635    match t {
636        WgslType::Vec2F => 2,
637        WgslType::Vec3F => 3,
638        WgslType::Vec4F => 4,
639        _ => 0,
640    }
641}
642
643/// Returns the WgslType for a vec of the given component count. `1` maps
644/// to scalar f32, `2..4` to the matching vec type, anything else to
645/// `Unknown`. Convenience for AST-driven rewrites that need to construct
646/// a target type from an inferred component count.
647pub(crate) fn vec_of_size(n: usize) -> WgslType {
648    match n {
649        1 => WgslType::F32,
650        2 => WgslType::Vec2F,
651        3 => WgslType::Vec3F,
652        4 => WgslType::Vec4F,
653        _ => WgslType::Unknown,
654    }
655}
656
657fn swizzle_target_type(len: usize) -> WgslType {
658    match len {
659        1 => WgslType::F32,
660        2 => WgslType::Vec2F,
661        3 => WgslType::Vec3F,
662        4 => WgslType::Vec4F,
663        _ => WgslType::Unknown,
664    }
665}
666
/// Split an expression at its top-level binary operators `+ - * /` and
/// return the operand slices, untrimmed. "Top level" means outside any
/// `(...)`, `<...>` or `[...]`.
///
/// An operator at the start, or directly after another operator, is a
/// unary sign and does not split (`a * -b` → `["a ", " -b"]`). Returns
/// `None` when the expression is a single term.
fn split_binop_operands(expr: &str) -> Option<Vec<&str>> {
    let (mut paren, mut angle, mut bracket) = (0i32, 0i32, 0i32);
    let mut expecting_operand = true;
    let mut cuts = Vec::new();

    for (idx, byte) in expr.bytes().enumerate() {
        if byte.is_ascii_whitespace() {
            continue;
        }
        let at_top_level = paren == 0 && angle == 0 && bracket == 0;
        match byte {
            b'(' => paren += 1,
            b')' => paren -= 1,
            b'<' => angle += 1,
            b'>' => angle -= 1,
            b'[' => bracket += 1,
            b']' => bracket -= 1,
            b'+' | b'-' | b'*' | b'/' if at_top_level && !expecting_operand => {
                cuts.push(idx);
                expecting_operand = true;
                continue;
            }
            _ => {}
        }
        expecting_operand = false;
    }

    if cuts.is_empty() {
        return None;
    }
    let starts = std::iter::once(0).chain(cuts.iter().map(|&cut| cut + 1));
    let ends = cuts.iter().copied().chain(std::iter::once(expr.len()));
    Some(starts.zip(ends).map(|(from, to)| &expr[from..to]).collect())
}
728
729/// Pick the "widest" of two types — vec wins over scalar; among vecs
730/// keep the larger; on conflict default to the first non-unknown.
731fn widen(a: WgslType, b: WgslType) -> WgslType {
732    match (a, b) {
733        (WgslType::Unknown, x) | (x, WgslType::Unknown) => x,
734        (x, y) if x == y => x,
735        (WgslType::Vec4F, _) | (_, WgslType::Vec4F) => WgslType::Vec4F,
736        (WgslType::Vec3F, _) | (_, WgslType::Vec3F) => WgslType::Vec3F,
737        (WgslType::Vec2F, _) | (_, WgslType::Vec2F) => WgslType::Vec2F,
738        _ => a,
739    }
740}
741
742// ---------------------------------------------------------------------------
743// Pass: scalar→vector broadcast injection in known multi-arg builtins
744// ---------------------------------------------------------------------------
745
/// Builtins that require all of their arguments to share one type in WGSL.
/// HLSL quietly broadcasts scalars to vectors (and truncates mismatched
/// vectors) for these calls; WGSL does not. For each call to one of these,
/// if one arg has a known vec type and another is a scalar, the scalar is
/// wrapped in the matching vec constructor.
///
/// Two entries exist for vector-width mismatches rather than scalars:
/// - `dot(a, b)` needs both args to be the same vec size. The corpus has
///   `dot(vec2, vec3)` shapes that HLSL silently truncates, so the larger
///   arg is narrowed to the smaller vec.
/// - `cross(a, b)` is strictly vec3×vec3 in WGSL. The corpus pattern is
///   `cross(vec3, textureSample(...))`, whose vec4 second arg naga rejects
///   with `wrong type passed as argument #2 to cross`. The narrower-vec
///   rule truncates it to vec3.
const BROADCAST_BUILTINS: &[&str] = &[
    "clamp",
    "min",
    "max",
    "mix",
    "step",
    "smoothstep",
    "pow",
    "dot",
    "cross",
];
768
/// Builtins whose return type matches their argument type (HLSL-style
/// "polymorphic", element-typed functions). Used by
/// [`SymbolTable::infer_expr_type`], which takes the smallest vec among a
/// call's arguments as the result type and falls back to `f32` when no
/// argument is a known vec.
///
/// This overlaps with [`BROADCAST_BUILTINS`] but is not a superset of it.
/// `dot` and `cross` are deliberately absent because their result type is
/// fixed (`f32` and `vec3<f32>` respectively) and is answered by
/// `known_call_return_type` instead. The extra entries here are
/// pure-passthrough math functions that don't need broadcast rewriting.
const POLY_BUILTINS: &[&str] = &[
    "clamp",
    "min",
    "max",
    "mix",
    "step",
    "smoothstep",
    "pow",
    "abs",
    "sign",
    "floor",
    "ceil",
    "fract",
    "exp",
    "log",
    "sin",
    "cos",
    "tan",
    "sqrt",
    "normalize",
];
795
796/// Walk `src` looking for calls to broadcast-prone builtins; rewrite
797/// scalar arguments to vec constructors where the dominant arg is a vec.
798///
799/// **Recursion strategy**: for every matched call, we first recursively
800/// apply the broadcast rewrite to each argument, *then* decide whether
801/// to broadcast at the current level. This ensures that
802/// `pow(clamp(<vec>, 0, 1), 0.5)` both rewrites the inner `clamp` and
803/// the outer `pow` in a single sweep, instead of needing a fixed-point
804/// loop with the walker skipping past unrewritten inner calls.
805pub fn inject_broadcasts(src: &str, table: &SymbolTable) -> String {
806    let bytes = src.as_bytes();
807    let mut out = String::with_capacity(src.len() + 64);
808    let mut i = 0;
809
810    while i < bytes.len() {
811        // Try to match a builtin name on a word boundary.
812        let mut matched = None;
813        for name in BROADCAST_BUILTINS {
814            let len = name.len();
815            if i + len < bytes.len()
816                && &bytes[i..i + len] == name.as_bytes()
817                && bytes[i + len] == b'('
818                && (i == 0 || !(bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_'))
819            {
820                matched = Some(*name);
821                break;
822            }
823        }
824
825        let Some(name) = matched else {
826            out.push(bytes[i] as char);
827            i += 1;
828            continue;
829        };
830
831        // Balance parens to find the call's args.
832        let arg_start = i + name.len() + 1;
833        let mut j = arg_start;
834        let mut depth = 1i32;
835        while j < bytes.len() {
836            match bytes[j] {
837                b'(' => depth += 1,
838                b')' => {
839                    depth -= 1;
840                    if depth == 0 {
841                        break;
842                    }
843                }
844                _ => {}
845            }
846            j += 1;
847        }
848        if j >= bytes.len() {
849            // Unbalanced — emit verbatim and bail.
850            out.push_str(&src[i..]);
851            return out;
852        }
853
854        let args_text = &src[arg_start..j];
855        let raw_args = split_top_level_commas(args_text);
856
857        // Recurse first so any nested broadcast-prone call gets rewritten
858        // before we infer the type at this level. The recursed text is
859        // the source of truth for both the inferred-type lookup and the
860        // emit step below.
861        let rewritten_args: Vec<String> = raw_args
862            .iter()
863            .map(|a| inject_broadcasts(a.trim(), table))
864            .collect();
865        let arg_types: Vec<WgslType> = rewritten_args
866            .iter()
867            .map(|a| table.infer_expr_type(a))
868            .collect();
869
870        // Pick the *smallest* vec among args as the target. HLSL allows
871        // `max(vec3, vec4)` (truncating the vec4 to vec3); WGSL refuses
872        // and demands all args share a type. The smallest vec wins so we
873        // truncate larger args down — matches HLSL's "first arg
874        // constrains" rule for cases like `ret = max(ret_vec3, tex2D())`.
875        let target = arg_types
876            .iter()
877            .copied()
878            .filter(|t| t.is_vec())
879            .reduce(narrower);
880
881        out.push_str(name);
882        out.push('(');
883        match target {
884            Some(vec_ty) => {
885                for (k, (arg, ty)) in rewritten_args.iter().zip(arg_types.iter()).enumerate() {
886                    if k > 0 {
887                        out.push_str(", ");
888                    }
889                    let wrap = arg_wrap(*ty, vec_ty, arg);
890                    match wrap {
891                        ArgWrap::Broadcast => {
892                            out.push_str(vec_ty.wgsl_name());
893                            out.push('(');
894                            out.push_str(arg);
895                            out.push(')');
896                        }
897                        ArgWrap::Truncate(swizzle) => {
898                            out.push('(');
899                            out.push_str(arg);
900                            out.push(')');
901                            out.push_str(swizzle);
902                        }
903                        ArgWrap::None => {
904                            out.push_str(arg);
905                        }
906                    }
907                }
908            }
909            None => {
910                // No vec dominance — emit args (post-recursion) verbatim.
911                for (k, arg) in rewritten_args.iter().enumerate() {
912                    if k > 0 {
913                        out.push_str(", ");
914                    }
915                    out.push_str(arg);
916                }
917            }
918        }
919        out.push(')');
920        i = j + 1;
921    }
922
923    out
924}
925
/// How [`arg_wrap`] says a single call argument should be coerced to the
/// call's target vec type before it is emitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ArgWrap {
    /// Emit the argument unchanged.
    None,
    /// Wrap the argument in the target vec constructor, e.g. `vec3<f32>(0.5)`.
    Broadcast,
    /// Parenthesise the argument and append this swizzle, e.g. `(x).xyz`.
    Truncate(&'static str),
}
931
932/// Decide how to coerce an arg of `arg_ty` to `target` (the smallest vec
933/// among the call's args). Scalars get broadcast; larger vecs get
934/// truncated; equal vecs and unknowns mostly pass through (except for
935/// literal scalars, which we still broadcast even when type inference
936/// returned Unknown).
937fn arg_wrap(arg_ty: WgslType, target: WgslType, arg: &str) -> ArgWrap {
938    if arg_ty == target {
939        return ArgWrap::None;
940    }
941    if arg_ty.is_scalar() {
942        return ArgWrap::Broadcast;
943    }
944    if matches!(arg_ty, WgslType::Unknown) && is_numeric_literal(arg.trim()) {
945        return ArgWrap::Broadcast;
946    }
947    if arg_ty.is_vec() && target.is_vec() && vec_size(arg_ty) > vec_size(target) {
948        return ArgWrap::Truncate(match vec_size(target) {
949            2 => ".xy",
950            3 => ".xyz",
951            _ => "",
952        });
953    }
954    ArgWrap::None
955}
956
957/// Reverse of [`widen`] — pick the smaller vec when both args are vecs.
958/// Used by the broadcast pass so a `max(vec3, vec4)` call truncates the
959/// vec4 to vec3 (HLSL semantics), not the other way round.
960fn narrower(a: WgslType, b: WgslType) -> WgslType {
961    match (a, b) {
962        (WgslType::Unknown, x) | (x, WgslType::Unknown) => x,
963        (x, y) if x == y => x,
964        (WgslType::Vec2F, _) | (_, WgslType::Vec2F) => WgslType::Vec2F,
965        (WgslType::Vec3F, _) | (_, WgslType::Vec3F) => WgslType::Vec3F,
966        (WgslType::Vec4F, _) | (_, WgslType::Vec4F) => WgslType::Vec4F,
967        _ => a,
968    }
969}
970
/// Split `s` on commas that are not nested inside parentheses or a WGSL
/// template list (`vec3<f32>`, `array<f32, 4>`). Pieces are returned
/// untrimmed; an empty input yields a single empty piece.
///
/// `<` and `>` are only treated as template brackets when they form a
/// template list (see [`wgsl_template_list_end`]). A comparison operator
/// such as the `>` in `f32(x > 0.5)` must not shift the nesting depth.
/// Counting every `<`/`>` did shift it, and every later top-level comma
/// was then missed.
fn split_top_level_commas(s: &str) -> Vec<&str> {
    let bytes = s.as_bytes();
    let mut out = Vec::new();
    let mut depth_paren = 0i32;
    let mut start = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        match bytes[i] {
            b'(' => depth_paren += 1,
            b')' => depth_paren -= 1,
            b'<' => {
                if let Some(end) = wgsl_template_list_end(bytes, i) {
                    // Commas inside a template list never separate args.
                    i = end + 1;
                    continue;
                }
            }
            b',' if depth_paren == 0 => {
                out.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
        i += 1;
    }
    out.push(&s[start..]);
    out
}

/// If the `<` at `open` starts a WGSL template list, return the index of
/// its matching `>`.
///
/// Mirrors WGSL's template-list discovery: the `<` must directly follow an
/// identifier character, and everything up to the matching `>` may only be
/// identifier characters, commas, spaces/tabs, or nested `<`/`>`. Anything
/// else (e.g. `.`, `(`, `=`) means the `<` is a comparison or shift operator.
fn wgsl_template_list_end(bytes: &[u8], open: usize) -> Option<usize> {
    let prev = *bytes.get(open.checked_sub(1)?)?;
    if !(prev.is_ascii_alphanumeric() || prev == b'_') {
        return None;
    }
    let mut depth = 0i32;
    for (k, &b) in bytes.iter().enumerate().skip(open) {
        match b {
            b'<' => depth += 1,
            b'>' => {
                depth -= 1;
                if depth == 0 {
                    return Some(k);
                }
            }
            b'_' | b',' | b' ' | b'\t' => {}
            c if c.is_ascii_alphanumeric() => {}
            _ => return None,
        }
    }
    None
}
993
994// ---------------------------------------------------------------------------
995// Pass: assignment-target type coercion
996// ---------------------------------------------------------------------------
997
998/// Walk the source for `<bare_ident> <op>= <expr>;` statements (where
999/// `<op>` is empty for plain assignment or one of `+ - * /` for compound
1000/// assignment). When the inferred type of `<expr>` doesn't match the
1001/// declared type of `<bare_ident>`, inject the HLSL implicit conversion
1002/// (broadcast or truncate).
1003///
1004/// This is the assignment-statement analogue of [`inject_truncations`],
1005/// which handles `var X: T = ...` declarations. The dominant motivating
1006/// case is `ret += tex3D(...)*0.5;` where `ret` is vec3 in the wrapper
1007/// preamble and the tex3D fallback returns vec4 — WGSL refuses with
1008/// `InvalidBinaryOperandTypes`.
1009///
1010/// Conservative scope:
1011/// - LHS must be a bare identifier (no swizzle, no array index).
1012/// - LHS must resolve to a known vec or scalar type in `table`.
1013/// - RHS must infer to a known type that differs from LHS.
1014/// - Compound `*=` / `/=` only fire when RHS is scalar→vec (broadcast)
1015///   or larger vec→smaller (truncate) — leaving matrix-mul ambiguity
1016///   well alone.
1017pub fn inject_assignment_coercions(src: &str, table: &SymbolTable) -> String {
1018    let bytes = src.as_bytes();
1019    let mut out = String::with_capacity(src.len() + 64);
1020    let mut i = 0usize;
1021    let mut at_stmt_start = true;
1022    let mut paren = 0i32;
1023
1024    while i < bytes.len() {
1025        // Track comments — never start a statement inside a comment.
1026        if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'/' {
1027            while i < bytes.len() && bytes[i] != b'\n' {
1028                out.push(bytes[i] as char);
1029                i += 1;
1030            }
1031            continue;
1032        }
1033        if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' {
1034            let s = i;
1035            i += 2;
1036            while i + 1 < bytes.len() && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
1037                i += 1;
1038            }
1039            if i + 1 < bytes.len() {
1040                i += 2;
1041            }
1042            out.push_str(&src[s..i]);
1043            continue;
1044        }
1045
1046        // Track paren depth — only attempt assignment matching at depth 0.
1047        match bytes[i] {
1048            b'(' => paren += 1,
1049            b')' => paren -= 1,
1050            _ => {}
1051        }
1052
1053        if !at_stmt_start || paren != 0 || !bytes[i].is_ascii_alphabetic() && bytes[i] != b'_' {
1054            // Update statement-start tracking and continue. Whitespace
1055            // leaves the state untouched; statement boundaries set it
1056            // back to true; everything else falsifies it.
1057            if !bytes[i].is_ascii_whitespace() {
1058                at_stmt_start = matches!(bytes[i], b';' | b'{' | b'}');
1059            }
1060            out.push(bytes[i] as char);
1061            i += 1;
1062            continue;
1063        }
1064
1065        // Try to parse `<ident>(.<swizzle>)? <op>= <expr>;`.
1066        let id_start = i;
1067        let mut p = i;
1068        while p < bytes.len() && (bytes[p].is_ascii_alphanumeric() || bytes[p] == b'_') {
1069            p += 1;
1070        }
1071        let id_end = p;
1072        let name = &src[id_start..id_end];
1073        // Optionally consume a swizzle suffix (`.x`, `.xy`, `.xyz`,
1074        // `.xyzw`). Lets us coerce `uv2.x = tex2D(...)` (a real MD2 pattern
1075        // in midgitstraights / suksma-neck) by treating the LHS as f32
1076        // when the swizzle is single-component, vecN otherwise. The
1077        // accessor characters mirror WGSL's set (`xyzw` + the legacy
1078        // `rgba` aliases normalised away earlier in the pipeline — kept
1079        // here as defence in depth in case a future pass reintroduces them).
1080        let mut sw_len = 0usize;
1081        if p < bytes.len() && bytes[p] == b'.' && p + 1 < bytes.len() {
1082            let sw_start = p + 1;
1083            let mut q = sw_start;
1084            while q < bytes.len()
1085                && matches!(
1086                    bytes[q],
1087                    b'x' | b'y' | b'z' | b'w' | b'r' | b'g' | b'b' | b'a'
1088                )
1089            {
1090                q += 1;
1091            }
1092            if q > sw_start && q - sw_start <= 4 {
1093                sw_len = q - sw_start;
1094                p = q;
1095            }
1096        }
1097        // Skip whitespace.
1098        while p < bytes.len() && bytes[p].is_ascii_whitespace() {
1099            p += 1;
1100        }
1101        if p >= bytes.len() {
1102            out.push_str(&src[id_start..p]);
1103            i = p;
1104            at_stmt_start = false;
1105            continue;
1106        }
1107        // Look for assignment op: `=`, `+=`, `-=`, `*=`, `/=`.
1108        let (op_byte, op_len) = match bytes[p] {
1109            b'=' if bytes.get(p + 1) != Some(&b'=') => (None, 1),
1110            b'+' if bytes.get(p + 1) == Some(&b'=') => (Some(b'+'), 2),
1111            b'-' if bytes.get(p + 1) == Some(&b'=') => (Some(b'-'), 2),
1112            b'*' if bytes.get(p + 1) == Some(&b'=') => (Some(b'*'), 2),
1113            b'/' if bytes.get(p + 1) == Some(&b'=') => (Some(b'/'), 2),
1114            _ => {
1115                // Not an assignment statement — pass through.
1116                out.push_str(&src[id_start..p]);
1117                i = p;
1118                at_stmt_start = false;
1119                continue;
1120            }
1121        };
1122        let op_end = p + op_len;
1123        // Read RHS until `;` at depth 0.
1124        let rhs_start = op_end;
1125        let mut q = rhs_start;
1126        let mut dpar = 0i32;
1127        let mut dbr = 0i32;
1128        while q < bytes.len() {
1129            match bytes[q] {
1130                b'(' => dpar += 1,
1131                b')' => dpar -= 1,
1132                b'[' => dbr += 1,
1133                b']' => dbr -= 1,
1134                b';' if dpar == 0 && dbr == 0 => break,
1135                _ => {}
1136            }
1137            q += 1;
1138        }
1139        if q >= bytes.len() {
1140            out.push_str(&src[id_start..q]);
1141            i = q;
1142            continue;
1143        }
1144        let rhs_raw = &src[rhs_start..q];
1145        let rhs_trimmed = rhs_raw.trim();
1146
1147        // Look up LHS type. If unknown, pass through.
1148        let Some(base_ty) = table.lookup(name) else {
1149            out.push_str(&src[id_start..q]);
1150            i = q;
1151            at_stmt_start = false;
1152            continue;
1153        };
1154
1155        // If we consumed a swizzle suffix, the effective LHS type is
1156        // derived from the swizzle length (single-component = scalar,
1157        // N-component = vec_N). The underlying ident still needs to be a
1158        // vec — bail otherwise (writing to `f32.x` makes no sense).
1159        let lhs_ty = if sw_len == 0 {
1160            base_ty
1161        } else if base_ty.is_vec() {
1162            match sw_len {
1163                1 => WgslType::F32,
1164                2 => WgslType::Vec2F,
1165                3 => WgslType::Vec3F,
1166                4 => WgslType::Vec4F,
1167                _ => {
1168                    out.push_str(&src[id_start..q]);
1169                    i = q;
1170                    at_stmt_start = false;
1171                    continue;
1172                }
1173            }
1174        } else {
1175            out.push_str(&src[id_start..q]);
1176            i = q;
1177            at_stmt_start = false;
1178            continue;
1179        };
1180
1181        // Don't bother for non-scalar/non-vec types (matrix assignment is
1182        // outside our scope).
1183        if !lhs_ty.is_scalar() && !lhs_ty.is_vec() {
1184            out.push_str(&src[id_start..q]);
1185            i = q;
1186            at_stmt_start = false;
1187            continue;
1188        }
1189
1190        let rhs_ty = table.infer_expr_type(rhs_trimmed);
1191
1192        // Emit `<ident>(.swizzle)? <op>= ` verbatim.
1193        out.push_str(&src[id_start..op_end]);
1194        out.push(' ');
1195
1196        // Coercion handles compound ops on scalar LHS and swizzle-LHS
1197        // scalars. The AdamFX CollaborationFX preset writes
1198        // `dx += GetPixel(...) - GetPixel(...);` with `dx: f32` — without
1199        // this, the `op_byte.is_none()` guard let the vec3 RHS slide
1200        // through to naga, which rejected the InvalidStoreTypes. The
1201        // midgitstraights / suksma-neck preset writes `uv2.x = 1 -
1202        // tex2D(...);` with `uv2: vec2<f32>` — the swizzle LHS path
1203        // narrows the LHS to f32 and truncates the vec4 RHS to `.x`.
1204        let coerced = match (lhs_ty, rhs_ty) {
1205            // vec = scalar: broadcast. Only for plain `=` and `+=`/`-=`
1206            // (broadcast doesn't change `*=`/`/=` semantics on vec, but
1207            // we avoid them out of paranoia).
1208            (l, WgslType::F32) | (l, WgslType::I32)
1209                if l.is_vec() && matches!(op_byte, None | Some(b'+') | Some(b'-')) =>
1210            {
1211                Some(format!("{}({})", l.wgsl_name(), rhs_trimmed))
1212            }
1213            // vec_n = vec_m where m > n: truncate.
1214            (l, r) if l.is_vec() && r.is_vec() && vec_size(r) > vec_size(l) => {
1215                let sw = match vec_size(l) {
1216                    2 => ".xy",
1217                    3 => ".xyz",
1218                    _ => "",
1219                };
1220                Some(format!("({}){}", rhs_trimmed, sw))
1221            }
1222            // f32 = vec: truncate with `.x`. Now fires for compound ops
1223            // too — `dx += vec3` becomes `dx += (vec3).x`, matching the
1224            // HLSL implicit-truncation semantics.
1225            (l, r) if l.is_scalar() && r.is_vec() => Some(format!("({}).x", rhs_trimmed)),
1226            _ => None,
1227        };
1228
1229        if let Some(new_rhs) = coerced {
1230            out.push_str(&new_rhs);
1231        } else {
1232            out.push_str(rhs_raw);
1233        }
1234        out.push(';');
1235        i = q + 1;
1236        at_stmt_start = true;
1237    }
1238
1239    out
1240}
1241
1242// ---------------------------------------------------------------------------
1243// Pass: scalar→vec implicit broadcast in declarations + f32←vec truncation
1244// ---------------------------------------------------------------------------
1245
1246/// Walk `var X: TYPE = <expr>;` (or `let`) declarations and fix HLSL-
1247/// implicit-conversion mismatches between LHS type and RHS type:
1248///
1249/// - **`f32 = vec`**: rewrite RHS to `(<expr>).x` (HLSL truncation).
1250/// - **`vec = scalar`**: rewrite RHS to `vecN<f32>(<expr>)` (HLSL broadcast).
1251///
1252/// The pass is intentionally conservative: only fires when the LHS type
1253/// is known and the inferred RHS type is the *other* of the two cases
1254/// above. Anything ambiguous (`Unknown` RHS) passes through.
1255///
1256/// Why this exists: MD2 user shaders frequently write
1257/// `float3 bg = pow(length(dz), 0.7)*2 + GetBlur1(uv).y*2;` — RHS is
1258/// scalar, LHS is vec3; HLSL silently broadcasts, WGSL rejects with
1259/// `the type of bg is expected to be vec3<f32> but got f32`. Mirror
1260/// case: `float lum = GetPixel(uv) * c.x;` — RHS vec3, LHS f32.
1261pub fn inject_truncations(src: &str, table: &SymbolTable) -> String {
1262    let bytes = src.as_bytes();
1263    let mut out = String::with_capacity(src.len() + 64);
1264    let mut i = 0;
1265
1266    while i < bytes.len() {
1267        // Skip comments to avoid false positives in disabled code blocks.
1268        if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'/' {
1269            while i < bytes.len() && bytes[i] != b'\n' {
1270                out.push(bytes[i] as char);
1271                i += 1;
1272            }
1273            continue;
1274        }
1275
1276        let kw_len = match keyword_at(bytes, i, &["var", "let"]) {
1277            Some(n) => n,
1278            None => {
1279                out.push(bytes[i] as char);
1280                i += 1;
1281                continue;
1282            }
1283        };
1284
1285        // Try to parse `var/let NAME: f32 = <expr>;`. Anything that
1286        // doesn't fit this exact shape passes through verbatim.
1287        let kw_end = i + kw_len;
1288        let mut p = kw_end;
1289        while p < bytes.len() && bytes[p].is_ascii_whitespace() {
1290            p += 1;
1291        }
1292        let name_start = p;
1293        while p < bytes.len() && (bytes[p].is_ascii_alphanumeric() || bytes[p] == b'_') {
1294            p += 1;
1295        }
1296        if p == name_start {
1297            out.push_str(&src[i..kw_end]);
1298            i = kw_end;
1299            continue;
1300        }
1301        let after_name = p;
1302        while p < bytes.len() && bytes[p].is_ascii_whitespace() {
1303            p += 1;
1304        }
1305        if p >= bytes.len() || bytes[p] != b':' {
1306            out.push_str(&src[i..after_name]);
1307            i = after_name;
1308            continue;
1309        }
1310        p += 1;
1311        // Read type until `=` or `;`.
1312        let ty_start = p;
1313        while p < bytes.len() && bytes[p] != b'=' && bytes[p] != b';' {
1314            p += 1;
1315        }
1316        let ty_str = src[ty_start..p].trim();
1317        let lhs_ty = WgslType::from_decl_str(ty_str);
1318        // Only act on scalar or vec LHS — matrix/bool/unknown pass through.
1319        if !matches!(
1320            lhs_ty,
1321            WgslType::F32 | WgslType::I32 | WgslType::Vec2F | WgslType::Vec3F | WgslType::Vec4F
1322        ) {
1323            out.push_str(&src[i..p]);
1324            i = p;
1325            continue;
1326        }
1327        if p >= bytes.len() || bytes[p] != b'=' {
1328            // No initialiser — passthrough.
1329            out.push_str(&src[i..p]);
1330            i = p;
1331            continue;
1332        }
1333        let eq_at = p;
1334        p += 1;
1335        // RHS until `;` at depth 0.
1336        let rhs_start = p;
1337        let mut depth_paren = 0i32;
1338        let mut depth_bracket = 0i32;
1339        while p < bytes.len() {
1340            match bytes[p] {
1341                b'(' => depth_paren += 1,
1342                b')' => depth_paren -= 1,
1343                b'[' => depth_bracket += 1,
1344                b']' => depth_bracket -= 1,
1345                b';' if depth_paren == 0 && depth_bracket == 0 => break,
1346                _ => {}
1347            }
1348            p += 1;
1349        }
1350        if p >= bytes.len() {
1351            // Unterminated — bail.
1352            out.push_str(&src[i..]);
1353            return out;
1354        }
1355        let rhs = &src[rhs_start..p];
1356        let rhs_trimmed = rhs.trim();
1357        let rhs_ty = table.infer_expr_type(rhs_trimmed);
1358
1359        // Emit the prefix (`var NAME: TYPE`) verbatim.
1360        out.push_str(&src[i..=eq_at]);
1361
1362        match (lhs_ty, rhs_ty) {
1363            // f32/i32 = vec: truncate with `.x`.
1364            (l, r) if l.is_scalar() && r.is_vec() => {
1365                out.push_str(" (");
1366                out.push_str(rhs_trimmed);
1367                out.push_str(").x");
1368            }
1369            // vec = scalar: broadcast via constructor.
1370            (l, WgslType::F32) | (l, WgslType::I32) if l.is_vec() => {
1371                out.push(' ');
1372                out.push_str(l.wgsl_name());
1373                out.push('(');
1374                out.push_str(rhs_trimmed);
1375                out.push(')');
1376            }
1377            // vecN = vecM where N < M: truncate via swizzle.
1378            (l, r) if l.is_vec() && r.is_vec() && vec_size(r) > vec_size(l) => {
1379                let swizzle = match vec_size(l) {
1380                    2 => ".xy",
1381                    3 => ".xyz",
1382                    _ => "",
1383                };
1384                out.push_str(" (");
1385                out.push_str(rhs_trimmed);
1386                out.push(')');
1387                out.push_str(swizzle);
1388            }
1389            _ => {
1390                out.push_str(rhs);
1391            }
1392        }
1393        out.push(';');
1394        i = p + 1;
1395    }
1396
1397    out
1398}
1399
1400// ---------------------------------------------------------------------------
1401// Pass: swizzle-LHS assignment rewrite
1402// ---------------------------------------------------------------------------
1403
1404/// Rewrite `target.<swizzle> <op>= <rhs>;` into a full-vector
1405/// reconstruction. WGSL refuses multi-component swizzles on the LHS of an
1406/// assignment (`invalid left-hand side of assignment`); HLSL allows them
1407/// freely. The dominant cases in `test-presets-200/`:
1408///
1409/// - `ret.xy *= diff;`
1410/// - `ret.xyz = tex2D(...).xyz;`
1411/// - `ret.zy /= diff2;`
1412///
1413/// Rewrite shape (for `target: vec3<f32>`):
1414///
1415/// `ret.xy = expr;`  →  `ret = vec3<f32>((expr).x, (expr).y, ret.z);`
1416/// `ret.xy *= expr;` →  `ret = vec3<f32>(ret.x * (expr).x, ret.y * (expr).y, ret.z);`
1417///
1418/// Constraints:
1419/// - `target` must resolve to a known vec3/vec4 in `table`. Skip otherwise.
1420/// - Swizzle length must be 2..=target_size, all distinct, all in `{x,y,z,w}`.
1421/// - Reordering swizzles (`yx`, `zy`) are handled — each target component
1422///   picks the matching swizzle slot.
1423/// - Scalar RHS is broadcast (each lane uses the same `(expr)`).
1424/// - RHS with unknown type is treated as a vec matching the swizzle width
1425///   (the dominant case in real shaders); the resulting WGSL still
1426///   validates because positional swizzle access on a vec is well-typed.
1427pub fn inject_swizzle_assignments(src: &str, table: &SymbolTable) -> String {
1428    use regex::Regex;
1429    use std::sync::LazyLock;
1430
1431    static SWZ_RE: LazyLock<Regex> = LazyLock::new(|| {
1432        // Line-anchored: only fires at statement start (after whitespace
1433        // or the wrapper's body prefix). Captures:
1434        //   1 = indent, 2 = target ident, 3 = swizzle, 4 = op, 5 = rhs
1435        Regex::new(
1436            r"(?m)^([\t ]*)([A-Za-z_][A-Za-z0-9_]*)\.([xyzwrgba]{2,4})\s*([+\-*/]?=)\s*([^;]+);",
1437        )
1438        .unwrap()
1439    });
1440
1441    SWZ_RE
1442        .replace_all(src, |caps: &regex::Captures| {
1443            let indent = &caps[1];
1444            let target = &caps[2];
1445            let swizzle = &caps[3];
1446            let op = &caps[4];
1447            let rhs = caps[5].trim();
1448
1449            // Normalise rgba/xyzw — accept either MD2 colour spelling but
1450            // emit canonical xyzw inside fs_main (the wrapper uses xyzw).
1451            let swizzle_xyzw = normalise_swizzle(swizzle);
1452
1453            // All components must be unique — duplicate components on the
1454            // LHS are an HLSL error too; bail rather than fabricate.
1455            if !all_unique(&swizzle_xyzw) {
1456                return caps[0].to_string();
1457            }
1458
1459            // Lookup target type. Only handle vec3/vec4; bail otherwise.
1460            let target_ty = match table.lookup(target) {
1461                Some(WgslType::Vec3F) => WgslType::Vec3F,
1462                Some(WgslType::Vec4F) => WgslType::Vec4F,
1463                _ => return caps[0].to_string(),
1464            };
1465            let target_size = vec_size(target_ty);
1466
1467            // Swizzle width must not exceed target's width.
1468            if swizzle_xyzw.len() > target_size {
1469                return caps[0].to_string();
1470            }
1471
1472            // Compute new value for each component of the target vec.
1473            let comps = match target_size {
1474                3 => &['x', 'y', 'z'][..],
1475                4 => &['x', 'y', 'z', 'w'][..],
1476                _ => return caps[0].to_string(),
1477            };
1478            let rhs_is_scalar = table.infer_expr_type(rhs).is_scalar();
1479            let mut lane_exprs: Vec<String> = Vec::with_capacity(target_size);
1480            for &c in comps {
1481                if let Some(pos) = swizzle_xyzw.iter().position(|&s| s == c) {
1482                    // Lane in RHS at this swizzle position. The RHS is
1483                    // indexed positionally (lane 0 = .x, lane 1 = .y,
1484                    // …) — NOT by the swizzle letter at this position,
1485                    // which equals the target component `c` itself.
1486                    let lane_letter = match pos {
1487                        0 => 'x',
1488                        1 => 'y',
1489                        2 => 'z',
1490                        3 => 'w',
1491                        _ => return caps[0].to_string(),
1492                    };
1493                    let rhs_lane = if rhs_is_scalar || swizzle_xyzw.len() == 1 {
1494                        // Scalar RHS — broadcast unchanged value.
1495                        format!("({rhs})")
1496                    } else {
1497                        format!("({rhs}).{lane_letter}")
1498                    };
1499                    let new_val = match op {
1500                        "=" => rhs_lane,
1501                        "+=" => format!("{target}.{c} + {rhs_lane}"),
1502                        "-=" => format!("{target}.{c} - {rhs_lane}"),
1503                        "*=" => format!("{target}.{c} * {rhs_lane}"),
1504                        "/=" => format!("{target}.{c} / {rhs_lane}"),
1505                        _ => return caps[0].to_string(),
1506                    };
1507                    lane_exprs.push(new_val);
1508                } else {
1509                    // Component not in swizzle — keep unchanged.
1510                    lane_exprs.push(format!("{target}.{c}"));
1511                }
1512            }
1513
1514            format!(
1515                "{indent}{target} = {ty}({args});",
1516                ty = target_ty.wgsl_name(),
1517                args = lane_exprs.join(", ")
1518            )
1519        })
1520        .to_string()
1521}
1522
/// Translate colour-swizzle letters (`r/g/b/a`) into their positional
/// equivalents (`x/y/z/w`). All other characters are passed through as-is.
/// MD2 user shaders mix both spellings, while the codegen wrapper always
/// uses `xyzw` for its locals. Canonicalising here keeps later per-lane
/// lookups consistent.
fn normalise_swizzle(s: &str) -> Vec<char> {
    const COLOUR_TO_POSITION: [(char, char); 4] = [('r', 'x'), ('g', 'y'), ('b', 'z'), ('a', 'w')];

    let mut letters = Vec::with_capacity(s.len());
    for ch in s.chars() {
        let mapped = COLOUR_TO_POSITION
            .iter()
            .find(|&&(colour, _)| colour == ch)
            .map_or(ch, |&(_, position)| position);
        letters.push(mapped);
    }
    letters
}
1538
/// `true` when every letter is one of `x/y/z/w` and none is repeated.
/// Any other character, or a duplicate, yields `false`. An empty slice
/// yields `true`.
fn all_unique(letters: &[char]) -> bool {
    let mut seen_mask = 0u8;
    letters.iter().all(|&letter| {
        let bit: u8 = match letter {
            'x' => 1,
            'y' => 2,
            'z' => 4,
            'w' => 8,
            _ => return false,
        };
        let first_time = (seen_mask & bit) == 0;
        seen_mask |= bit;
        first_time
    })
}
1556
#[cfg(test)]
mod tests {
    //! Unit tests for the symbol table, expression type inference, and the
    //! three rewrite passes (`inject_broadcasts`, `inject_truncations`,
    //! `inject_swizzle_assignments`). Each pass test pins both the
    //! positive rewrite and the conservative "leave it alone" cases.

    use super::*;

    // ----------------------------------------------------------------
    // Symbol table
    // ----------------------------------------------------------------

    #[test]
    fn symbol_table_picks_up_var_decls() {
        let src = "var foo: vec3<f32> = vec3<f32>(0.0); let bar: f32 = 1.0;";
        let t = SymbolTable::from_source(src);
        assert_eq!(t.lookup("foo"), Some(WgslType::Vec3F));
        assert_eq!(t.lookup("bar"), Some(WgslType::F32));
    }

    #[test]
    fn symbol_table_seeds_wrapper_locals() {
        // `uv` is in the wrapper preamble — it must resolve even with no
        // user-side `var` for it.
        let t = SymbolTable::from_source("// nothing here");
        assert_eq!(t.lookup("uv"), Some(WgslType::Vec2F));
        assert_eq!(t.lookup("texsize"), Some(WgslType::Vec4F));
        assert_eq!(t.lookup("q1"), Some(WgslType::F32));
        assert_eq!(t.lookup("M_PI_2"), Some(WgslType::F32));
    }

    // ----------------------------------------------------------------
    // Expression type inference
    // ----------------------------------------------------------------

    #[test]
    fn infer_expr_numeric_literal() {
        let t = SymbolTable::from_source("");
        assert_eq!(t.infer_expr_type("1.0"), WgslType::F32);
        assert_eq!(t.infer_expr_type("-3.14"), WgslType::F32);
        // Bare integer literals are also inferred as f32, not i32.
        assert_eq!(t.infer_expr_type("0"), WgslType::F32);
    }

    #[test]
    fn infer_expr_known_helpers() {
        let t = SymbolTable::from_source("");
        assert_eq!(t.infer_expr_type("GetPixel(uv)"), WgslType::Vec3F);
        assert_eq!(t.infer_expr_type("GetBlur1(uv)"), WgslType::Vec3F);
        assert_eq!(t.infer_expr_type("lum(ret)"), WgslType::F32);
        assert_eq!(
            t.infer_expr_type("textureSample(sampler_main_texture, sampler_main, uv)"),
            WgslType::Vec4F
        );
    }

    #[test]
    fn infer_expr_constructor() {
        let t = SymbolTable::from_source("");
        assert_eq!(t.infer_expr_type("vec3<f32>(1, 0, 0)"), WgslType::Vec3F);
        assert_eq!(t.infer_expr_type("vec4<f32>(c)"), WgslType::Vec4F);
    }

    #[test]
    fn infer_expr_swizzle_narrows_to_scalar() {
        let t = SymbolTable::from_source("var c: vec4<f32> = vec4<f32>(1);");
        assert_eq!(t.infer_expr_type("c.x"), WgslType::F32);
        assert_eq!(t.infer_expr_type("c.xy"), WgslType::Vec2F);
        assert_eq!(t.infer_expr_type("c.xyz"), WgslType::Vec3F);
    }

    #[test]
    fn infer_expr_binop_widens_to_vec() {
        let t = SymbolTable::from_source("var c: vec4<f32> = vec4<f32>(1);");
        // `GetPixel(uv) * c.x` → vec3 * f32 = vec3.
        assert_eq!(t.infer_expr_type("GetPixel(uv) * c.x"), WgslType::Vec3F);
    }

    // ----------------------------------------------------------------
    // Call-argument broadcasting (`inject_broadcasts`)
    // ----------------------------------------------------------------

    #[test]
    fn broadcast_clamp_with_scalar_bounds() {
        let t = SymbolTable::from_source("");
        let src = "ret = clamp(GetBlur1(uv), 0.0, 1.0);";
        let out = inject_broadcasts(src, &t);
        assert!(
            out.contains("clamp(GetBlur1(uv), vec3<f32>(0.0), vec3<f32>(1.0))"),
            "got: {out}"
        );
    }

    #[test]
    fn broadcast_pow_with_scalar_exponent() {
        let t = SymbolTable::from_source("");
        let src = "ret = pow(GetPixel(uv), 0.5);";
        let out = inject_broadcasts(src, &t);
        assert!(
            out.contains("pow(GetPixel(uv), vec3<f32>(0.5))"),
            "got: {out}"
        );
    }

    #[test]
    fn broadcast_mix_with_scalar_lerp_factor() {
        let t = SymbolTable::from_source("");
        let src = "ret = mix(a, b, 0.3);";
        // Without context (a, b unknown), no rewrite — the pass is
        // conservative.
        let out = inject_broadcasts(src, &t);
        assert_eq!(out, src);

        // Now with vec3 a and b in scope:
        let t = SymbolTable::from_source(
            "var a: vec3<f32> = vec3<f32>(0); var b: vec3<f32> = vec3<f32>(1);",
        );
        let out = inject_broadcasts(src, &t);
        assert!(out.contains("mix(a, b, vec3<f32>(0.3))"), "got: {out}");
    }

    #[test]
    fn broadcast_truncates_larger_vec_to_smaller() {
        // Real preset pattern: `ret = max(ret_vec3, tex2D(...))`. HLSL
        // truncates the vec4 result to vec3; WGSL needs an explicit `.xyz`.
        let t = SymbolTable::from_source("var ret: vec3<f32> = vec3<f32>(0);");
        let src = "ret = max(ret, textureSample(t, s, uv));";
        let out = inject_broadcasts(src, &t);
        assert!(
            out.contains("max(ret, (textureSample(t, s, uv)).xyz)"),
            "got: {out}"
        );
    }

    #[test]
    fn broadcast_skipped_when_all_args_scalar() {
        let t = SymbolTable::from_source("");
        let src = "var x: f32 = clamp(0.5, 0.0, 1.0);";
        let out = inject_broadcasts(src, &t);
        assert_eq!(out, src);
    }

    // ----------------------------------------------------------------
    // Declaration coercion (`inject_truncations`)
    //
    // Despite its name, this pass handles both directions: vec→scalar
    // and vec4→vec3/vec2 truncation, and scalar→vec wrapping in a
    // constructor.
    // ----------------------------------------------------------------

    #[test]
    fn truncation_f32_eq_vec3_inserts_dot_x() {
        let t = SymbolTable::from_source("");
        let src = "var gx1: f32 = GetPixel(uv) + GetBlur1(uv);";
        let out = inject_truncations(src, &t);
        assert!(
            out.contains("var gx1: f32 = (GetPixel(uv) + GetBlur1(uv)).x;"),
            "got: {out}"
        );
    }

    #[test]
    fn truncation_skipped_when_rhs_already_scalar() {
        let t = SymbolTable::from_source("");
        let src = "var x: f32 = 1.0 + 2.0;";
        let out = inject_truncations(src, &t);
        assert_eq!(out, src);
    }

    #[test]
    fn truncation_skipped_when_lhs_is_vec() {
        let t = SymbolTable::from_source("");
        let src = "var v: vec3<f32> = GetPixel(uv);";
        let out = inject_truncations(src, &t);
        assert_eq!(out, src);
    }

    #[test]
    fn broadcast_vec3_eq_scalar_wraps_in_constructor() {
        // Real preset pattern: `float3 bg = pow(length(dz), 0.7)*2 + GetBlur1(uv).y*2;`
        // RHS is f32, LHS is vec3 — must inject `vec3<f32>(<rhs>)`.
        let t = SymbolTable::from_source("var dz: vec3<f32> = vec3<f32>(0);");
        let src = "var bg: vec3<f32> = pow(length(dz), 0.7)*2 + GetBlur1(uv).y*2;";
        let out = inject_truncations(src, &t);
        assert!(
            out.contains(
                "var bg: vec3<f32> = vec3<f32>(pow(length(dz), 0.7)*2 + GetBlur1(uv).y*2);"
            ),
            "got: {out}"
        );
    }

    #[test]
    fn truncation_vec3_eq_vec4_appends_xyz() {
        // Real preset pattern: `float3 ret2 = tex2D(sampler_main, uv);`
        // where tex2D returns vec4 — must inject `.xyz`.
        let t = SymbolTable::from_source("");
        let src = "var ret2: vec3<f32> = textureSample(t, s, uv);";
        let out = inject_truncations(src, &t);
        assert!(
            out.contains("var ret2: vec3<f32> = (textureSample(t, s, uv)).xyz;"),
            "got: {out}"
        );
    }

    #[test]
    fn truncation_vec2_eq_vec4_appends_xy() {
        let t = SymbolTable::from_source("");
        let src = "var sam: vec2<f32> = textureSample(t, s, uv);";
        let out = inject_truncations(src, &t);
        assert!(
            out.contains("var sam: vec2<f32> = (textureSample(t, s, uv)).xy;"),
            "got: {out}"
        );
    }

    #[test]
    fn broadcast_skipped_when_rhs_already_vec() {
        // The RHS mixes a vec3 and a scalar (`c.x`), but the whole
        // expression widens to vec3 (see `infer_expr_binop_widens_to_vec`).
        // The scalar→vec wrap must key off the overall RHS type and leave
        // this declaration untouched. This is distinct from
        // `truncation_skipped_when_lhs_is_vec`, whose RHS has no scalar
        // sub-expression.
        let t = SymbolTable::from_source("var c: vec4<f32> = vec4<f32>(1);");
        let src = "var v: vec3<f32> = GetPixel(uv) * c.x;";
        let out = inject_truncations(src, &t);
        assert_eq!(out, src);
    }

    #[test]
    fn truncation_skipped_on_unknown_rhs_type() {
        let t = SymbolTable::from_source("");
        // RHS references an unknown identifier — we can't be sure of the
        // type, so don't touch it.
        let src = "var x: f32 = some_user_function(uv);";
        let out = inject_truncations(src, &t);
        assert_eq!(out, src);
    }

    // ----------------------------------------------------------------
    // Swizzle-LHS assignment rewrite
    // ----------------------------------------------------------------

    #[test]
    fn swizzle_xy_assignment_on_vec3_target() {
        // `ret` is vec3<f32> per the wrapper preamble. `ret.xy = expr;`
        // rebuilds the full vec3 with the .x/.y lanes from expr and the
        // .z lane unchanged.
        let t = SymbolTable::from_source("");
        let src = "ret.xy = diff;";
        let out = inject_swizzle_assignments(src, &t);
        assert_eq!(
            out, "ret = vec3<f32>((diff).x, (diff).y, ret.z);",
            "got: {out}"
        );
    }

    #[test]
    fn swizzle_xyz_full_replace_on_vec3_target() {
        let t = SymbolTable::from_source("");
        let src = "ret.xyz = tex2D(s, uv).xyz;";
        let out = inject_swizzle_assignments(src, &t);
        // Whole vec replaced — `.z` lane sourced from `(rhs).z`, not the
        // unchanged target.
        assert_eq!(
            out,
            "ret = vec3<f32>((tex2D(s, uv).xyz).x, (tex2D(s, uv).xyz).y, (tex2D(s, uv).xyz).z);",
        );
    }

    #[test]
    fn swizzle_compound_mul_assign() {
        let t = SymbolTable::from_source("");
        let src = "ret.xy *= diff;";
        let out = inject_swizzle_assignments(src, &t);
        // Compound: each lane = target.<comp> * rhs.<lane>.
        assert_eq!(
            out,
            "ret = vec3<f32>(ret.x * (diff).x, ret.y * (diff).y, ret.z);",
        );
    }

    #[test]
    fn swizzle_reordered_zy() {
        // `ret.zy = expr;` puts expr.x into ret.z and expr.y into ret.y.
        let t = SymbolTable::from_source("");
        let src = "ret.zy = pair;";
        let out = inject_swizzle_assignments(src, &t);
        // x unchanged, y = pair.y, z = pair.x.
        assert_eq!(out, "ret = vec3<f32>(ret.x, (pair).y, (pair).x);",);
    }

    #[test]
    fn swizzle_rgba_normalised_to_xyzw() {
        let t = SymbolTable::from_source("");
        let src = "ret.rg = pair;";
        let out = inject_swizzle_assignments(src, &t);
        // `rg` maps to `xy`. Same emit shape as `ret.xy`.
        assert_eq!(out, "ret = vec3<f32>((pair).x, (pair).y, ret.z);",);
    }

    #[test]
    fn swizzle_skipped_when_target_unknown() {
        // Unknown identifier — must NOT rewrite.
        let t = SymbolTable::from_source("");
        let src = "user_var.xy = stuff;";
        let out = inject_swizzle_assignments(src, &t);
        assert_eq!(out, src);
    }

    #[test]
    fn swizzle_skipped_on_single_component() {
        // Single-component swizzles like `ret.z = stuff;` are valid WGSL —
        // leave them alone.
        let t = SymbolTable::from_source("");
        let src = "ret.z = stuff;";
        let out = inject_swizzle_assignments(src, &t);
        assert_eq!(out, src);
    }

    #[test]
    fn swizzle_skipped_when_duplicate_components() {
        // `ret.xx = foo` is meaningless on the LHS — bail.
        let t = SymbolTable::from_source("");
        let src = "ret.xx = pair;";
        let out = inject_swizzle_assignments(src, &t);
        assert_eq!(out, src);
    }

    #[test]
    fn swizzle_div_compound_on_vec3() {
        let t = SymbolTable::from_source("");
        let src = "ret.zy /= diff2;";
        let out = inject_swizzle_assignments(src, &t);
        // x unchanged, y = ret.y / diff2.y, z = ret.z / diff2.x.
        assert_eq!(
            out,
            "ret = vec3<f32>(ret.x, ret.y / (diff2).y, ret.z / (diff2).x);",
        );
    }

    #[test]
    fn swizzle_xy_on_vec4_target() {
        // vec4 case: `ret4: vec4<f32>`, swizzle `xy`, only x/y change.
        let t = SymbolTable::from_source("var ret4: vec4<f32> = vec4<f32>(0.0);");
        let src = "ret4.xy = pair;";
        let out = inject_swizzle_assignments(src, &t);
        assert_eq!(out, "ret4 = vec4<f32>((pair).x, (pair).y, ret4.z, ret4.w);",);
    }
}