onedrop_hlsl/rewrite/
qa_stub.rs

//! Pass: `_qa..._qh` stub injection.
2
3use super::*;
4
// ---------------------------------------------------------------------------
// Pass: `_qa..._qh` stub injection
// ---------------------------------------------------------------------------
8
9/// `_qa`/`_qb`/.../`_qh` are MD2 runtime-defined aliases that pack the 32
10/// `q1..q32` evaluator channels into 8 × `float4` (`_qa = float4(q1, q2, q3, q4)`,
11/// `_qb = float4(q5, …)`, etc.). User comp shaders read them directly,
12/// typically via `mul(uv, float2x2(_qb))` for kaleidoscope rotation.
13///
14/// The standalone HLSL parser never sees their declarations because they
15/// live in the host-provided prelude — naga rejects with
16/// `no definition in scope for identifier: '_qb'`. We stub the full
17/// `_qa..._qh` family — many presets reference `_qb` and a long tail of
18/// the others.
19///
20/// For each referenced identifier we:
21///
22/// - prepend `float4 _qX = float4(0.0, 0.0, 0.0, 0.0);` at the start of
23///   `shader_body { ... }` so any later read picks up a defined value, and
24/// - expand every `float2x2(_qX)` call into the four-scalar constructor
25///   form `float2x2(_qX.x, _qX.y, _qX.z, _qX.w)` so it survives
26///   `replace_types` into a valid `mat2x2<f32>(...)` call (WGSL refuses
27///   the single-vec4 form).
28///
29/// If no `_qX` identifier is referenced, returns the input unchanged.
30pub(crate) fn inject_qa_stub(src: &str) -> String {
31    let Ok(tu) = parse_hlsl(src) else {
32        return src.to_string();
33    };
34    let mut uses = QaUses::default();
35    if let Some(body) = &tu.shader_body {
36        scan_block_for_qa(body, &mut uses);
37    }
38    for item in &tu.items {
39        if let Item::Function(f) = item {
40            scan_block_for_qa(&f.body, &mut uses);
41        }
42    }
43    if uses.referenced_mask == 0 {
44        return src.to_string();
45    }
46    let Some(body) = &tu.shader_body else {
47        return src.to_string();
48    };
49
50    let mut edits: Vec<TextEdit> = uses
51        .mat_call_spans
52        .iter()
53        .map(|(sp, letter)| TextEdit {
54            start: sp.start,
55            end: sp.end,
56            replacement: format!(
57                "float2x2(_q{letter}.x, _q{letter}.y, _q{letter}.z, _q{letter}.w)"
58            ),
59        })
60        .collect();
61
62    if let Some(open) = find_brace_after(src, body.span.start) {
63        let mut stub = String::new();
64        for (idx, letter) in QA_LETTERS.iter().enumerate() {
65            if uses.referenced_mask & (1u8 << idx) != 0 {
66                stub.push_str(&format!(" float4 _q{letter} = float4(0.0, 0.0, 0.0, 0.0);"));
67            }
68        }
69        edits.push(TextEdit {
70            start: open,
71            end: open,
72            replacement: stub,
73        });
74    }
75    apply_edits(src, &mut edits)
76}
77
/// Letters `a..h` map to the 8 four-channel groups packing `q1..q32`.
///
/// Indexed by bit position: `QA_LETTERS[i]` is the suffix of the alias
/// tracked by bit `i` of `QaUses::referenced_mask` (bit 0 = `_qa`).
const QA_LETTERS: [char; 8] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'];
80
/// Map an identifier to its `_qa..._qh` bit index.
///
/// Returns `Some(0)` for `_qa` through `Some(7)` for `_qh`. Any other name,
/// including longer ones such as `_qab` or letters past `h`, yields `None`.
fn qa_letter_index(name: &str) -> Option<usize> {
    match name.strip_prefix("_q")?.as_bytes() {
        [c @ b'a'..=b'h'] => Some(usize::from(*c - b'a')),
        _ => None,
    }
}
94
95#[derive(Default)]
96struct QaUses {
97    /// Bit-set of referenced `_qa..._qh` identifiers (bit 0 = `_qa`).
98    referenced_mask: u8,
99    /// Each entry is `(call_span, ident_letter)` for a `float2x2(_qX)` call
100    /// to expand into the four-scalar form.
101    mat_call_spans: Vec<(Span, char)>,
102}
103
104fn scan_block_for_qa(b: &Block, out: &mut QaUses) {
105    for s in &b.stmts {
106        scan_stmt_for_qa(s, out);
107    }
108}
109
110fn scan_stmt_for_qa(s: &Stmt, out: &mut QaUses) {
111    match s {
112        Stmt::LocalDecl(d) => {
113            if let Some(init) = &d.init {
114                scan_expr_for_qa(init, out);
115            }
116        }
117        Stmt::Assign(a) => {
118            scan_expr_for_qa(&a.target, out);
119            scan_expr_for_qa(&a.value, out);
120        }
121        Stmt::Expr(e) => scan_expr_for_qa(e, out),
122        Stmt::If(i) => {
123            scan_expr_for_qa(&i.cond, out);
124            scan_stmt_for_qa(&i.then_branch, out);
125            if let Some(e) = &i.else_branch {
126                scan_stmt_for_qa(e, out);
127            }
128        }
129        Stmt::While(w) => {
130            scan_expr_for_qa(&w.cond, out);
131            scan_stmt_for_qa(&w.body, out);
132        }
133        Stmt::For(f) => {
134            if let Some(init) = &f.init {
135                scan_stmt_for_qa(init, out);
136            }
137            if let Some(c) = &f.cond {
138                scan_expr_for_qa(c, out);
139            }
140            if let Some(st) = &f.step {
141                scan_expr_for_qa(st, out);
142            }
143            scan_stmt_for_qa(&f.body, out);
144        }
145        Stmt::Return(Some(e)) => scan_expr_for_qa(e, out),
146        Stmt::Block(b) => scan_block_for_qa(b, out),
147        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
148    }
149}
150
151fn scan_expr_for_qa(e: &Expr, out: &mut QaUses) {
152    match e {
153        Expr::Ident(name, _) => {
154            if let Some(idx) = qa_letter_index(name) {
155                out.referenced_mask |= 1u8 << idx;
156            }
157        }
158        Expr::Lit(_) => {}
159        Expr::Binary(b) => {
160            scan_expr_for_qa(&b.lhs, out);
161            scan_expr_for_qa(&b.rhs, out);
162        }
163        Expr::Unary(u) => scan_expr_for_qa(&u.operand, out),
164        Expr::Ternary(t) => {
165            scan_expr_for_qa(&t.cond, out);
166            scan_expr_for_qa(&t.then_expr, out);
167            scan_expr_for_qa(&t.else_expr, out);
168        }
169        Expr::Call(c) => {
170            if c.callee == "float2x2"
171                && c.args.len() == 1
172                && let Expr::Ident(n, _) = &c.args[0]
173                && let Some(idx) = qa_letter_index(n)
174            {
175                let letter = QA_LETTERS[idx];
176                out.mat_call_spans.push((c.span, letter));
177                out.referenced_mask |= 1u8 << idx;
178                return;
179            }
180            for a in &c.args {
181                scan_expr_for_qa(a, out);
182            }
183        }
184        Expr::Member(m) => scan_expr_for_qa(&m.base, out),
185        Expr::Swizzle(s) => scan_expr_for_qa(&s.base, out),
186        Expr::Index(i) => {
187            scan_expr_for_qa(&i.base, out);
188            scan_expr_for_qa(&i.index, out);
189        }
190        Expr::InitList(l) => {
191            for e in &l.elems {
192                scan_expr_for_qa(e, out);
193            }
194        }
195        Expr::Assign(a) => {
196            scan_expr_for_qa(&a.target, out);
197            scan_expr_for_qa(&a.value, out);
198        }
199    }
200}
201
/// Byte offset just past the first `{` at or after `start`.
///
/// The parser gives `shader_body { … }` a span that begins at the keyword,
/// not at the brace, and this bridges that gap. Returns `None` when no `{`
/// follows `start`, including when `start` lies past the end of `src`.
/// That only happens with malformed source, and callers then skip injection.
fn find_brace_after(src: &str, start: u32) -> Option<u32> {
    let from = start as usize;
    let tail = src.as_bytes().get(from..)?;
    let offset = tail.iter().position(|&byte| byte == b'{')?;
    Some((from + offset + 1) as u32)
}
217}