onedrop_hlsl/rewrite/
embedded_assign.rs

1//! Pass: lift assignment-as-expression out of mid-statement positions.
2//!
3//! HLSL accepts `lerp(a, tmp = GetBlur1(uv), b)` — the inner `tmp = …`
4//! assigns *and* yields the new value as the call arg. WGSL has no
5//! assignment expression and rejects the source at parse time with
6//! "expected `)` after call arguments; found Eq" (when our HLSL parser
7//! is the one rejecting) or with `invalid type for binary operator` /
8//! similar (when naga sees `lerp(a, =, b)`).
9//!
10//! Rewrite shape: for each enclosing statement that contains one or more
11//! embedded [`Expr::Assign`] sub-expressions, emit
12//!
13//! 1. an insertion at the statement's start with the assignment lifted to
14//!    a standalone `<target> = <value>;` (preserving the HLSL side effect),
15//! 2. a replacement at the assignment's position with just the target text
16//!    (so the surrounding call's arg count is preserved).
17//!
18//! Only fires when the parser produces an `Expr::Assign` outside the
19//! top-level `Stmt::Assign` slot — i.e. nested inside a `Stmt::Return`
20//! value, an `Stmt::LocalDecl` init, an `Stmt::Expr` (typically a bare
21//! call), or another `Expr::Call`'s arg list. Direct `Stmt::Assign`s are
22//! already top-level and don't need lifting.
23//!
24//! Limitations:
25//! - Nested `(x = y = z)` chains in a call arg lift only the outermost
26//!   level; the inner chain would re-enter [`parse_expr`] without going
27//!   through this pass. Rare in the corpus.
28//! - The lifted assignment runs at the *start* of the enclosing statement,
29//!   which is the HLSL evaluation order for the single-assignment case
30//!   (call args evaluate left-to-right; the assigning arg's side effect
31//!   happens before the call). For multiple assigns in the same statement
32//!   we emit them in source order, which still matches left-to-right.
33
34use super::*;
35
36pub(crate) fn rewrite_embedded_assigns(src: &str) -> String {
37    let Ok(tu) = parse_hlsl(src) else {
38        return src.to_string();
39    };
40    let mut edits = Vec::new();
41    if let Some(body) = &tu.shader_body {
42        walk_block(body, src, &mut edits);
43    }
44    for item in &tu.items {
45        if let Item::Function(f) = item {
46            walk_block(&f.body, src, &mut edits);
47        }
48    }
49    apply_edits(src, &mut edits)
50}
51
52fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
53    for s in &b.stmts {
54        walk_stmt(s, src, edits);
55    }
56}
57
58fn walk_stmt(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
59    // Collect every embedded assign reachable from this statement's own
60    // expressions (not counting the top-level Stmt::Assign — that case
61    // is fine as-is). Then emit one prelude insertion at the statement's
62    // start and one replacement per embedded assign.
63    let stmt_start = match stmt_span_start(s, src) {
64        Some(p) => p,
65        None => return, // can't lift safely without an anchor
66    };
67    let mut assigns: Vec<&AssignExpr> = Vec::new();
68    match s {
69        Stmt::LocalDecl(d) => {
70            if let Some(init) = &d.init {
71                collect_embedded(init, &mut assigns);
72            }
73            if let Some(arr) = &d.array_len {
74                collect_embedded(arr, &mut assigns);
75            }
76        }
77        Stmt::Assign(a) => {
78            // `target` is the LHS — if a swizzle or index, its sub-expressions
79            // may contain embedded assigns (rare).
80            collect_embedded(&a.target, &mut assigns);
81            collect_embedded(&a.value, &mut assigns);
82        }
83        Stmt::Expr(e) => collect_embedded(e, &mut assigns),
84        Stmt::Return(Some(e)) => collect_embedded(e, &mut assigns),
85        Stmt::If(i) => {
86            collect_embedded(&i.cond, &mut assigns);
87            walk_stmt(&i.then_branch, src, edits);
88            if let Some(e) = &i.else_branch {
89                walk_stmt(e, src, edits);
90            }
91        }
92        Stmt::While(w) => {
93            collect_embedded(&w.cond, &mut assigns);
94            walk_stmt(&w.body, src, edits);
95        }
96        Stmt::For(f) => {
97            // `for` init / step run their own assignment semantics — the
98            // existing parser captures `i = i + 1` as Expr::Assign for the
99            // step; that's the legitimate use. We don't lift those out.
100            if let Some(cond) = &f.cond {
101                collect_embedded(cond, &mut assigns);
102            }
103            walk_stmt(&f.body, src, edits);
104            if let Some(init) = &f.init {
105                walk_stmt(init, src, edits);
106            }
107        }
108        Stmt::Block(b) => walk_block(b, src, edits),
109        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
110    }
111    if assigns.is_empty() {
112        return;
113    }
114    // Sort source-order so prelude lines match left-to-right evaluation.
115    assigns.sort_by_key(|a| a.span.start);
116    let mut prelude = String::new();
117    for a in &assigns {
118        prelude.push_str(slice(src, a.target.span()));
119        prelude.push(' ');
120        prelude.push_str(assign_op_str(a.op));
121        prelude.push(' ');
122        prelude.push_str(slice(src, a.value.span()));
123        prelude.push_str("; ");
124        // Replace the assign-expr in place with just the target identifier.
125        edits.push(TextEdit {
126            start: a.span.start,
127            end: a.span.end,
128            replacement: slice(src, a.target.span()).to_string(),
129        });
130    }
131    edits.push(TextEdit {
132        start: stmt_start,
133        end: stmt_start,
134        replacement: prelude,
135    });
136}
137
138fn assign_op_str(op: AssignOp) -> &'static str {
139    match op {
140        AssignOp::Set => "=",
141        AssignOp::Add => "+=",
142        AssignOp::Sub => "-=",
143        AssignOp::Mul => "*=",
144        AssignOp::Div => "/=",
145        AssignOp::Rem => "%=",
146    }
147}
148
149fn stmt_span_start(s: &Stmt, src: &str) -> Option<u32> {
150    match s {
151        Stmt::LocalDecl(d) => Some(d.span.start),
152        Stmt::Assign(a) => Some(a.span.start),
153        Stmt::Expr(e) => Some(e.span().start),
154        // [`Stmt::Return`] doesn't carry an explicit span; the AST only
155        // remembers the inner expression. Walk back from the expr's start
156        // to locate the `return` keyword that precedes it (skipping a run
157        // of whitespace).
158        Stmt::Return(Some(e)) => find_return_keyword_before(src, e.span().start),
159        Stmt::Return(None) => None,
160        Stmt::If(i) => Some(i.span.start),
161        Stmt::While(w) => Some(w.span.start),
162        Stmt::For(f) => Some(f.span.start),
163        Stmt::Block(b) => Some(b.span.start),
164        Stmt::Break | Stmt::Continue => None,
165    }
166}
167
/// The keyword searched for in front of a returned expression.
const RETURN_KW: &[u8] = b"return";

/// Locates the `return` keyword that immediately precedes `expr_start`.
///
/// Only ASCII whitespace may sit between the keyword and `expr_start`, and
/// the byte before the keyword (if any) must not be an identifier character,
/// so `noreturn x` or `_return x` do not match. Returns the keyword's byte
/// offset on success and `None` otherwise, which keeps the caller from
/// emitting a misplaced prelude.
fn find_return_keyword_before(src: &str, expr_start: u32) -> Option<u32> {
    let head = &src.as_bytes()[..expr_start as usize];

    // Length of `head` once trailing whitespace is dropped.
    let trimmed_len = head
        .iter()
        .rposition(|b| !b.is_ascii_whitespace())
        .map_or(0, |last| last + 1);

    // Whatever precedes the keyword; `None` when `head` does not end in it.
    let before_kw = head[..trimmed_len].strip_suffix(RETURN_KW)?;

    match before_kw.last() {
        // Keyword is merely the tail of a longer identifier.
        Some(&prev) if prev.is_ascii_alphanumeric() || prev == b'_' => None,
        _ => Some(before_kw.len() as u32),
    }
}
196
197fn collect_embedded<'e>(e: &'e Expr, out: &mut Vec<&'e AssignExpr>) {
198    match e {
199        Expr::Assign(a) => {
200            // Capture this assignment AND descend into its RHS so a chain
201            // like `lerp(a, tmp = (rhs = X), b)` would lift both (rare).
202            out.push(a);
203            collect_embedded(&a.value, out);
204        }
205        Expr::Binary(b) => {
206            collect_embedded(&b.lhs, out);
207            collect_embedded(&b.rhs, out);
208        }
209        Expr::Unary(u) => collect_embedded(&u.operand, out),
210        Expr::Ternary(t) => {
211            collect_embedded(&t.cond, out);
212            collect_embedded(&t.then_expr, out);
213            collect_embedded(&t.else_expr, out);
214        }
215        Expr::Call(c) => {
216            for a in &c.args {
217                collect_embedded(a, out);
218            }
219        }
220        Expr::Member(m) => collect_embedded(&m.base, out),
221        Expr::Swizzle(s) => collect_embedded(&s.base, out),
222        Expr::Index(i) => {
223            collect_embedded(&i.base, out);
224            collect_embedded(&i.index, out);
225        }
226        Expr::InitList(l) => {
227            for e in &l.elems {
228                collect_embedded(e, out);
229            }
230        }
231        Expr::Lit(_) | Expr::Ident(_, _) => {}
232    }
233}
234
235fn slice(src: &str, sp: Span) -> &str {
236    &src[sp.start as usize..sp.end as usize]
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// An assignment used as a call argument of a top-level assignment
    /// statement is lifted in front of that statement.
    #[test]
    fn lift_single_assign_in_call_arg() {
        // Reduced form of the Whoah corpus shape (`tmp = GetBlur1(uvi)` as
        // a `lerp` arg), with plain identifiers standing in for the calls.
        let src = "shader_body { ret = lerp(a, tmp = b, c); }";
        let out = rewrite_embedded_assigns(src);
        // The prelude assignment is inserted before the statement and the
        // argument is reduced to the bare target.
        assert!(
            out.contains("tmp = b; ret = lerp(a, tmp, c);"),
            "got: {out}"
        );
    }

    /// Same lift, but the enclosing statement is a `return` inside a
    /// helper function, so the prelude must land before the keyword.
    #[test]
    fn lift_inside_return() {
        let src = "float Get(float2 uvi) { float tmp; return lerp(a, tmp = b, c); } \
             shader_body { ret = float3(Get(uv), 0); }";
        let out = rewrite_embedded_assigns(src);
        assert!(
            out.contains("tmp = b; return lerp(a, tmp, c);"),
            "got: {out}"
        );
    }

    /// Source without embedded assignments comes back byte-for-byte.
    #[test]
    fn no_change_when_no_embedded_assigns() {
        let src = "shader_body { ret = lerp(a, b, c); }";
        let out = rewrite_embedded_assigns(src);
        assert_eq!(out, src);
    }

    #[test]
    fn full_whoah_warp_translates_without_embedded_assign_in_output() {
        // End-to-end check: feed the full HLSL through translate_shader
        // and confirm `tmp = GetBlur1(uvi)` doesn't survive as an embedded
        // assignment expression (which would die at naga parse).
        let src = include_str!("../../tests/whoah_warp.hlsl");
        let wgsl = crate::translate_shader(src).expect("translate ok");
        // The assignment should only appear in standalone-statement form
        // (`tmp = GetBlur1(uvi);`), never as a call-arg expression, which
        // would be followed by a comma.
        let bad = "tmp = GetBlur1(uvi),";
        assert!(
            !wgsl.contains(bad),
            "embedded assign survived translate:\n{wgsl}"
        );
    }

    #[test]
    fn full_whoah_warp_parses_and_rewrites() {
        // Full warp shader text extracted from the failing preset. Locks
        // both the parse acceptance and the lifting outcome.
        let src = include_str!("../../tests/whoah_warp.hlsl");
        let parsed = crate::parse::parse_hlsl(src);
        assert!(parsed.is_ok(), "parse err: {:?}", parsed.err());
        let out = rewrite_embedded_assigns(src);
        // The lifted line must appear as a standalone assignment in `Get1`,
        // directly in front of the `return`.
        assert!(
            out.contains("tmp = GetBlur1(uvi); return lerp"),
            "lift did not fire on Get1:\n{out}"
        );
    }

    #[test]
    fn corpus_shape_survives_full_translate_pipeline() {
        // End-to-end: feed the same source through translate_shader and
        // confirm the embedded `tmp = GetBlur1(uvi)` doesn't survive into
        // WGSL output (which would die at naga parse with `expected )`).
        let src = "float3 Get1 (float2 uvi) {float3 tmp; float2 pix; \
                   return lerp (GetPixel(uvi), tmp = GetBlur1(uvi),change*4);} \
                   shader_body { ret = Get1(uv); }";
        let wgsl = crate::translate_shader(src).unwrap();
        // NOTE(review): this condition holds whenever the standalone form
        // (`…;`) is present, even if an embedded copy were also still in
        // the output; it is weaker than the `"tmp = GetBlur1(uvi),"` check
        // used in `full_whoah_warp_translates_without_embedded_assign_in_output`.
        assert!(
            !wgsl.contains("tmp = GetBlur1(uvi)") || wgsl.contains("tmp = GetBlur1(uvi);"),
            "raw assign-as-arg leaked into WGSL: {wgsl}"
        );
    }

    #[test]
    fn exact_corpus_shape_whoah_get1() {
        // Reproduces the failing line in
        // `whoah dj g - CLOUD APPS 022 - Copy.milk` verbatim — `lerp` with
        // a space before `(` and `tmp = GetBlur1(uvi)` as the middle arg.
        let src = "float3 Get1 (float2 uvi) {float3 tmp; float2 pix; \
                   return lerp (GetPixel(uvi), tmp = GetBlur1(uvi),change*4);} \
                   shader_body { ret = Get1(uv); }";
        let out = rewrite_embedded_assigns(src);
        // Surrounding text, including the missing space after the comma,
        // must be preserved exactly.
        assert!(
            out.contains("tmp = GetBlur1(uvi); return lerp (GetPixel(uvi), tmp,change*4);"),
            "got: {out}"
        );
    }

    #[test]
    fn top_level_assign_not_lifted() {
        // Stmt::Assign is already a top-level assignment — leave it alone.
        let src = "shader_body { tmp = b; ret = tmp; }";
        let out = rewrite_embedded_assigns(src);
        assert_eq!(out, src);
    }
}