// onedrop_hlsl/rewrite/chained_init.rs

//! Pass: lower chained assignment in local-decl init position.
//!
//! HLSL accepts `T x = y = expr;` (and deeper chains like
//! `T x = y = z = expr;`), evaluating right-to-left. WGSL has no
//! assignment expression — `var x: T = (y = expr);` is rejected by the
//! parser. The corpus shape that hits this is a Martin/Rovastar idiom
//! found in roughly 8 of 2 000 randomly sampled presets:
//!
//! ```hlsl
//! float2 ruv = uv = 0.5 + (uv-0.5)*(1+(rc.y*0.05));
//! ```
//!
//! Rewrite: lift the inner assignments out as statements that run
//! *before* the decl, then initialise the decl from the leftmost LHS:
//!
//! ```hlsl
//! uv = 0.5 + (uv-0.5)*(1+(rc.y*0.05)); float2 ruv = uv;
//! ```
//!
//! Detection runs on the AST produced by [`crate::parse::parse_hlsl`].
//! The parser already lowers the chain into a sequence of synthetic
//! [`Stmt::Assign`]s followed by the [`Stmt::LocalDecl`]; this pass only
//! has to recognise that the synthetic assigns' spans start *after* the
//! decl's span starts in source order (whereas in a hand-written
//! `uv = 0.5; float2 ruv = uv;` the assign's span starts *before* the
//! decl's). A single edit replaces the original chained source with the
//! lowered statements; the existing trailing `;` is preserved.
28
29use super::*;
30
31pub(crate) fn rewrite_chained_assign_inits(src: &str) -> String {
32    let Ok(tu) = parse_hlsl(src) else {
33        return src.to_string();
34    };
35    let mut edits = Vec::new();
36    if let Some(body) = &tu.shader_body {
37        walk_block(body, src, &mut edits);
38    }
39    for item in &tu.items {
40        if let Item::Function(f) = item {
41            walk_block(&f.body, src, &mut edits);
42        }
43    }
44    apply_edits(src, &mut edits)
45}
46
47fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
48    let stmts = &b.stmts;
49    let mut i = 0;
50    while i < stmts.len() {
51        if let Stmt::LocalDecl(d) = &stmts[i] {
52            let chain_len = count_preceding_synthetic(stmts, i, d.span.start);
53            if chain_len > 0 {
54                emit_chain_edit(d, &stmts[i - chain_len..i], src, edits);
55            }
56        }
57        descend(&stmts[i], src, edits);
58        i += 1;
59    }
60}
61
62fn descend(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
63    match s {
64        Stmt::If(i) => {
65            descend(&i.then_branch, src, edits);
66            if let Some(e) = &i.else_branch {
67                descend(e, src, edits);
68            }
69        }
70        Stmt::While(w) => descend(&w.body, src, edits),
71        Stmt::For(f) => descend(&f.body, src, edits),
72        Stmt::Block(b) => walk_block(b, src, edits),
73        _ => {}
74    }
75}
76
77/// How many `Stmt::Assign`s immediately before index `i` are synthetic
78/// chained-init lowerings — identified by a span that starts *after* the
79/// LocalDecl's span starts (i.e. the assignment text sits inside the
80/// chained `T x = y = …;` block, but the AST lifted it ahead of the decl).
81fn count_preceding_synthetic(stmts: &[Stmt], i: usize, decl_start: u32) -> usize {
82    let mut k = 0;
83    while k < i {
84        let prev = &stmts[i - k - 1];
85        let Stmt::Assign(a) = prev else { break };
86        if a.span.start <= decl_start {
87            break;
88        }
89        k += 1;
90    }
91    k
92}
93
94fn emit_chain_edit(d: &LocalDecl, synth: &[Stmt], src: &str, edits: &mut Vec<TextEdit>) {
95    // synth is in AST order: synth[0] is the innermost assign (RHS = the
96    // user's final expr); synth[len-1] is the outermost (RHS = the next
97    // chain link). The source text for the whole chained init spans from
98    // the decl's start to the innermost assign's value end.
99    let last_value_end = match &synth[0] {
100        Stmt::Assign(a) => a.value.span().end,
101        _ => return,
102    };
103    let start = d.span.start;
104    if last_value_end <= start {
105        return; // sanity: spans must point forward
106    }
107
108    let mut text = String::new();
109    // Emit assignments in evaluation order (innermost first).
110    for stmt in synth {
111        let Stmt::Assign(a) = stmt else {
112            return;
113        };
114        text.push_str(slice(src, a.target.span()));
115        text.push_str(" = ");
116        text.push_str(slice(src, a.value.span()));
117        text.push_str("; ");
118    }
119    // Emit the decl: original "T name" text, then `= <leftmost LHS>`.
120    text.push_str(slice(src, d.span));
121    if let Some(init) = &d.init {
122        text.push_str(" = ");
123        text.push_str(slice(src, init.span()));
124    }
125
126    edits.push(TextEdit {
127        start,
128        end: last_value_end,
129        replacement: text,
130    });
131}
132
133fn slice(src: &str, sp: Span) -> &str {
134    &src[sp.start as usize..sp.end as usize]
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn two_deep_chain_lowered() {
        let src = "shader_body { float2 ruv = uv = 0.5; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(out.contains("uv = 0.5; float2 ruv = uv"), "got: {out}");
        // Trailing `;` from the original source is preserved.
        assert!(out.trim_end_matches('}').trim_end().ends_with(';'));
    }

    #[test]
    fn three_deep_chain_lowered() {
        let src = "shader_body { float2 ruv = uv = uv2 = 0.5; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(
            out.contains("uv2 = 0.5; uv = uv2; float2 ruv = uv"),
            "got: {out}"
        );
    }

    #[test]
    fn hand_written_assign_then_decl_untouched() {
        let src = "shader_body { uv = 0.5; float2 ruv = uv; }";
        let out = rewrite_chained_assign_inits(src);
        assert_eq!(out, src);
    }

    #[test]
    fn consecutive_chains_each_lowered() {
        // Each chain's lifted assignment must be paired with its own decl:
        // the backward scan for the second decl has to stop at the first.
        let src = "shader_body { float a = b = 1.0; float c = d = 2.0; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(out.contains("b = 1.0; float a = b;"), "got: {out}");
        assert!(out.contains("d = 2.0; float c = d;"), "got: {out}");
    }

    #[test]
    fn chain_in_nested_block_lowered() {
        // Exercises the recursion into an `if` body.
        let src = "shader_body { if (uv.x > 0.5) { float2 ruv = uv = 0.5; } }";
        let out = rewrite_chained_assign_inits(src);
        assert!(out.contains("uv = 0.5; float2 ruv = uv;"), "got: {out}");
    }

    #[test]
    fn realistic_corpus_shape() {
        // The exact idiom from MilkDrop2077.11337.milk and 7 sibling
        // presets in the seed-42 2000-sample.
        let src = "shader_body { float3 rc = GetBlur1(uv); \
                   float2 ruv = uv = 0.5 + (uv-0.5)*(1+(rc.y*0.05)); \
                   ret = tex2D(sampler_main, ruv).xyz; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(
            out.contains("uv = 0.5 + (uv-0.5)*(1+(rc.y*0.05)); float2 ruv = uv"),
            "got: {out}"
        );
        // The non-chained statements around it must survive untouched.
        assert!(out.contains("float3 rc = GetBlur1(uv);"));
        assert!(out.contains("ret = tex2D(sampler_main, ruv).xyz;"));
    }
}