onedrop_hlsl/rewrite/
swizzle_assign.rs

//! Pass: lower multi-component swizzle assignments to per-component writes.
//!
//! HLSL accepts `vec.xy = expr` / `vec.xy += expr` / `color.rgb = …`
//! verbatim. WGSL rejects them with `invalid left-hand side of assignment` and
//! the hint *"WGSL does not support assignments to swizzles; consider
//! assigning each component individually"*. This is the dominant warp-side
//! failure on the in-the-wild corpus (≈ 100 cases in the 2 000-sample set).
//!
//! Rewrite shape: `lhs.SWIZ op= rhs` → `floatN _swztmp_K = (rhs); lhs.c0 op= _swztmp_K.x; lhs.c1 op= _swztmp_K.y; …`
//! where `K` is a per-pass counter so nested or chained rewrites don't
//! collide. We emit HLSL-shaped output (`floatN`) so the regex pipeline
//! can lower it to `vecN<f32>` like any other declaration. Single-component
//! swizzle assigns (`uv.x = …`) are left alone — WGSL accepts them.
//!
//! We deliberately skip:
//! - swizzles with repeated components (`uv.xx = …`) — undefined in HLSL
//!   too, so we surface the WGSL error instead of synthesising a fake last-write.
//! - assignments that sit inside a comma-chain (`a.xy = …, b += 1;`).
//!   The parser splits these into separate `Stmt::Assign`s, but the
//!   intervening `,` would dangle if we expanded one side into multiple
//!   statements. Conservative: leave the chain alone (rare in the corpus).

23use super::*;
24
25pub(crate) fn rewrite_swizzle_assigns(src: &str) -> String {
26    let Ok(tu) = parse_hlsl(src) else {
27        return src.to_string();
28    };
29    let mut state = RewriteState {
30        src,
31        edits: Vec::new(),
32        counter: 0,
33    };
34    if let Some(body) = &tu.shader_body {
35        walk_block(body, &mut state);
36    }
37    for item in &tu.items {
38        if let Item::Function(f) = item {
39            walk_block(&f.body, &mut state);
40        }
41    }
42    apply_edits(src, &mut state.edits)
43}
44
45struct RewriteState<'a> {
46    src: &'a str,
47    edits: Vec<TextEdit>,
48    counter: usize,
49}
50
51fn walk_block(b: &Block, st: &mut RewriteState) {
52    for s in &b.stmts {
53        walk_stmt(s, st);
54    }
55}
56
57fn walk_stmt(s: &Stmt, st: &mut RewriteState) {
58    match s {
59        Stmt::Assign(a) => try_emit(a, st),
60        Stmt::If(i) => {
61            walk_stmt(&i.then_branch, st);
62            if let Some(e) = &i.else_branch {
63                walk_stmt(e, st);
64            }
65        }
66        Stmt::While(w) => walk_stmt(&w.body, st),
67        Stmt::For(f) => {
68            if let Some(init) = &f.init {
69                walk_stmt(init, st);
70            }
71            walk_stmt(&f.body, st);
72        }
73        Stmt::Block(b) => walk_block(b, st),
74        Stmt::LocalDecl(_) | Stmt::Expr(_) | Stmt::Return(_) | Stmt::Break | Stmt::Continue => {}
75    }
76}
77
78fn try_emit(a: &AssignStmt, st: &mut RewriteState) {
79    let Expr::Swizzle(sw) = &a.target else {
80        return;
81    };
82    let comps = sw.components.as_str();
83    if comps.len() < 2 {
84        return; // single-component swizzle assign is valid WGSL
85    }
86    if !components_unique(comps) {
87        return; // duplicate writes — leave WGSL to surface the malformed source
88    }
89    if followed_by_comma(st.src, a.span.end) {
90        return; // part of a comma-chain — rewrite would dangle the comma
91    }
92
93    let n = comps.len();
94    let temp_ty = match n {
95        2 => "float2",
96        3 => "float3",
97        4 => "float4",
98        _ => return,
99    };
100    let id = st.counter;
101    st.counter += 1;
102    let tmp = format!("_swztmp_{id}");
103    let op_str = assign_op_str(a.op);
104
105    // Slice the original LHS base text + RHS text so any inner edits emitted
106    // by other passes (binop_vec, vec_cmp, …) on the same source still apply
107    // — we only own the wrapper text around them.
108    let base_text = slice(st.src, sw.base.span());
109    let rhs_text = slice(st.src, a.value.span());
110
111    let mut out = String::new();
112    out.push_str(temp_ty);
113    out.push(' ');
114    out.push_str(&tmp);
115    out.push_str(" = (");
116    out.push_str(rhs_text);
117    out.push(')');
118    for (i, c) in comps.chars().enumerate() {
119        let lhs_comp = normalise_component(c);
120        let rhs_comp = match i {
121            0 => 'x',
122            1 => 'y',
123            2 => 'z',
124            3 => 'w',
125            _ => unreachable!(),
126        };
127        out.push_str("; ");
128        out.push_str(base_text);
129        out.push('.');
130        out.push(lhs_comp);
131        out.push_str(op_str);
132        out.push_str(&tmp);
133        out.push('.');
134        out.push(rhs_comp);
135    }
136
137    st.edits.push(TextEdit {
138        start: a.span.start,
139        end: a.span.end,
140        replacement: out,
141    });
142}
143
/// Returns `true` when every character of `s` is a swizzle component and no
/// component is named twice.
///
/// `r`/`g`/`b`/`a` count as aliases of `x`/`y`/`z`/`w`, so `xr` is a
/// repeat. Any other character makes the result `false`. An empty string is
/// accepted.
fn components_unique(s: &str) -> bool {
    // One bit per positional component: x=1, y=2, z=4, w=8.
    let mut seen: u8 = 0;
    s.chars().all(|c| {
        let bit: u8 = match c {
            'x' | 'r' => 0b0001,
            'y' | 'g' => 0b0010,
            'z' | 'b' => 0b0100,
            'w' | 'a' => 0b1000,
            _ => return false,
        };
        let first_time = (seen & bit) == 0;
        seen |= bit;
        first_time
    })
}

/// Maps the colour-swizzle letters `r`, `g`, `b`, `a` onto their positional
/// equivalents `x`, `y`, `z`, `w`. Any other character is returned unchanged.
fn normalise_component(c: char) -> char {
    const COLOUR: [char; 4] = ['r', 'g', 'b', 'a'];
    const POSITIONAL: [char; 4] = ['x', 'y', 'z', 'w'];
    COLOUR
        .iter()
        .position(|&k| k == c)
        .map_or(c, |idx| POSITIONAL[idx])
}

172fn assign_op_str(op: AssignOp) -> &'static str {
173    match op {
174        AssignOp::Set => " = ",
175        AssignOp::Add => " += ",
176        AssignOp::Sub => " -= ",
177        AssignOp::Mul => " *= ",
178        AssignOp::Div => " /= ",
179        AssignOp::Rem => " %= ",
180    }
181}
182
/// Reports whether the first non-whitespace byte at or after byte offset
/// `end` is a `,`.
///
/// Only ASCII space, tab, CR and LF are skipped. Running off the end of
/// `src`, including an `end` past its length, yields `false`.
fn followed_by_comma(src: &str, end: u32) -> bool {
    let next_significant = src
        .as_bytes()
        .iter()
        .skip(end as usize)
        .find(|&&b| !matches!(b, b' ' | b'\t' | b'\r' | b'\n'));
    next_significant == Some(&b',')
}

196fn slice(src: &str, sp: Span) -> &str {
197    &src[sp.start as usize..sp.end as usize]
198}