onedrop_hlsl/rewrite/
array_lower.rs

1//! Pass: HLSL array → WGSL `array<T, N>` lowering (globals + locals).
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass 2: array global lowering
7// ---------------------------------------------------------------------------
8
/// Lower MD2-style global arrays (`const float4 samples[5] = {…};`) into a
/// WGSL `var` declaration with an explicit `array<T, N>` constructor.
///
/// The original `const` line confuses `replace_types` (which doesn't know
/// about `[N]`) and `rewrite_local_declarations` (which doesn't accept init
/// lists). Each matching declaration is rewritten textually to
/// `var samples: array<T, 5> = array<T, 5>(…);` (with `T` the WGSL element
/// type). That shape is already WGSL-friendly, so the regex pass leaves it
/// alone.
///
/// Targets the `const float4 samples[5] = {…};` idiom (cluster K + H).
/// Returns the source unchanged when parsing fails. Globals that don't match
/// the supported shape are left untouched.
#[cfg(test)]
pub(crate) fn lower_array_globals(src: &str) -> String {
    let Ok(tu) = parse_hlsl(src) else {
        return src.to_string();
    };
    let mut edits = Vec::new();
    collect_global_array_edits(&tu, src, &mut edits);
    apply_edits(src, &mut edits)
}
29
30/// Walk every top-level `Item::GlobalVar` with an array length + InitList
31/// and emit one text edit per match. Shared between [`lower_array_globals`]
32/// (the standalone pass) and [`apply_lowerings`] (the combined pass used
33/// by [`apply_all`]).
34pub(crate) fn collect_global_array_edits(
35    tu: &TranslationUnit,
36    src: &str,
37    edits: &mut Vec<TextEdit>,
38) {
39    for item in &tu.items {
40        let Item::GlobalVar(g) = item else { continue };
41        if g.array_len.is_none() {
42            continue;
43        }
44        let Some(init) = &g.init else {
45            continue;
46        };
47        let Expr::InitList(list) = init else {
48            continue;
49        };
50        let comp_ty = type_from_typeref(&g.ty);
51        let comp_size = vec_size(comp_ty);
52        if comp_size == 0 {
53            continue;
54        }
55        let Some(array_len_expr) = &g.array_len else {
56            continue;
57        };
58        let Expr::Lit(Lit {
59            value: LitValue::Int(n),
60            ..
61        }) = array_len_expr
62        else {
63            continue;
64        };
65        let n = *n as usize;
66        let comp_wgsl = comp_ty.wgsl_name();
67        let mut elems: Vec<String> = Vec::new();
68        // The HLSL flat init `{ 0.0, 1.0, ... }` (component-by-component)
69        // needs grouping into vecN. We also accept the already-vecN form
70        // `{ float4(...), ... }`.
71        if list.elems.len() == n {
72            // Each element is one vector — emit as-is in WGSL constructor
73            // shape.
74            for e in &list.elems {
75                elems.push(emit_expr_as_wgsl(e, src, comp_ty));
76            }
77        } else if list.elems.len() == n * comp_size {
78            // Flat list: pack `comp_size` scalars per element.
79            for chunk in list.elems.chunks(comp_size) {
80                let parts: Vec<String> = chunk
81                    .iter()
82                    .map(|e| emit_expr_as_wgsl(e, src, WgslType::F32))
83                    .collect();
84                elems.push(format!("{comp_wgsl}({})", parts.join(", ")));
85            }
86        } else {
87            continue; // shape mismatch, leave it alone
88        }
89        let body = elems.join(", ");
90        // Emit as WGSL `var`-initialised array. `const` would be cleaner
91        // but requires every element to be const-evaluable; `var` survives
92        // arithmetic like `11.0/3.0` and matches the address-space-free
93        // form that's legal at both module scope and inside `fs_main`
94        // (the existing pipeline drops the const-array body into the
95        // function body, so we can't assume module scope).
96        let replacement = format!(
97            "var {name}: array<{comp_wgsl}, {n}> = array<{comp_wgsl}, {n}>({body});",
98            name = g.name,
99        );
100        edits.push(TextEdit {
101            start: g.span.start,
102            end: span_includes_semi(src, g.span).end,
103            replacement,
104        });
105    }
106}
107
/// Lower local array declarations (`float3 m[3];`, `float arr[5] = {…};`)
/// into WGSL `var` declarations with an explicit `array<T, N>` type.
///
/// Scans the shader entry body first, then every function body, and applies
/// all collected edits in one go. The downstream regex
/// `rewrite_local_declarations` doesn't understand the `[N]` suffix or init
/// lists, so the lowered form drops through to naga as-is. Targets the
/// ougiel local-array failures and any preset with bare-typed local arrays.
/// Returns the source unchanged if it fails to parse.
#[cfg(test)]
pub(crate) fn lower_local_arrays(src: &str) -> String {
    let tu = match parse_hlsl(src) {
        Ok(tu) => tu,
        Err(_) => return src.to_string(),
    };
    let mut edits = Vec::new();
    if let Some(entry_body) = tu.shader_body.as_ref() {
        collect_local_array_edits(entry_body, src, &mut edits);
    }
    let function_bodies = tu.items.iter().filter_map(|item| match item {
        Item::Function(func) => Some(&func.body),
        _ => None,
    });
    for body in function_bodies {
        collect_local_array_edits(body, src, &mut edits);
    }
    apply_edits(src, &mut edits)
}
129
130pub(crate) fn collect_local_array_edits(block: &Block, src: &str, edits: &mut Vec<TextEdit>) {
131    for s in &block.stmts {
132        match s {
133            Stmt::LocalDecl(d) if d.array_len.is_some() => {
134                let Some(len_expr) = &d.array_len else {
135                    continue;
136                };
137                let Expr::Lit(Lit {
138                    value: LitValue::Int(n),
139                    ..
140                }) = len_expr
141                else {
142                    continue;
143                };
144                let comp_ty = type_from_typeref(&d.ty);
145                if comp_ty == WgslType::Unknown {
146                    continue;
147                }
148                let comp_wgsl = comp_ty.wgsl_name();
149                let replacement = match &d.init {
150                    None => format!("var {name}: array<{comp_wgsl}, {n}>;", name = d.name),
151                    Some(Expr::InitList(l)) => {
152                        // Scalar element types (`float arr[N]`) treat as
153                        // 1 component per element so the flat-init branch
154                        // works without a special case.
155                        let comp_size = if comp_ty.is_scalar() {
156                            1
157                        } else {
158                            vec_size(comp_ty)
159                        };
160                        if comp_size == 0 {
161                            continue;
162                        }
163                        let n_usize = *n as usize;
164                        let elems: Vec<String> = if comp_ty.is_scalar() && l.elems.len() == n_usize
165                        {
166                            l.elems
167                                .iter()
168                                .map(|e| emit_expr_as_wgsl(e, src, comp_ty))
169                                .collect()
170                        } else if l.elems.len() == n_usize {
171                            // One vector per array element.
172                            l.elems
173                                .iter()
174                                .map(|e| emit_expr_as_wgsl(e, src, comp_ty))
175                                .collect()
176                        } else if l.elems.len() == n_usize * comp_size {
177                            l.elems
178                                .chunks(comp_size)
179                                .map(|chunk| {
180                                    let parts: Vec<String> = chunk
181                                        .iter()
182                                        .map(|e| emit_expr_as_wgsl(e, src, WgslType::F32))
183                                        .collect();
184                                    format!("{comp_wgsl}({})", parts.join(", "))
185                                })
186                                .collect()
187                        } else {
188                            continue;
189                        };
190                        format!(
191                            "var {name}: array<{comp_wgsl}, {n}> = array<{comp_wgsl}, {n}>({});",
192                            elems.join(", "),
193                            name = d.name,
194                        )
195                    }
196                    Some(_) => continue, // unusual init shape, leave alone
197                };
198                edits.push(TextEdit {
199                    start: d.span.start,
200                    end: span_includes_semi(src, d.span).end,
201                    replacement,
202                });
203            }
204            Stmt::If(i) => {
205                if let Stmt::Block(b) = &*i.then_branch {
206                    collect_local_array_edits(b, src, edits);
207                }
208                if let Some(eb) = &i.else_branch
209                    && let Stmt::Block(b) = &**eb
210                {
211                    collect_local_array_edits(b, src, edits);
212                }
213            }
214            Stmt::While(w) => {
215                if let Stmt::Block(b) = &*w.body {
216                    collect_local_array_edits(b, src, edits);
217                }
218            }
219            Stmt::For(f) => {
220                if let Stmt::Block(b) = &*f.body {
221                    collect_local_array_edits(b, src, edits);
222                }
223            }
224            Stmt::Block(b) => collect_local_array_edits(b, src, edits),
225            _ => {}
226        }
227    }
228}
229
230/// Extend a span forward to include the trailing `;` (if present). The
231/// AST's GlobalVar span ends at the name, not the semicolon — for textual
232/// replacement we want to swallow the whole declaration.
233fn span_includes_semi(src: &str, span: Span) -> Span {
234    let bytes = src.as_bytes();
235    let mut i = span.end as usize;
236    while i < bytes.len() && bytes[i] != b';' {
237        i += 1;
238    }
239    if i < bytes.len() {
240        i += 1; // include the `;`
241    }
242    Span {
243        end: i as u32,
244        ..span
245    }
246}
247
248/// Emit an HLSL expression as WGSL text — narrow subset, enough for the
249/// array-global lowering case where elements are literals or simple
250/// arithmetic. Falls back to the raw original text if the expression is
251/// outside the supported shape.
252fn emit_expr_as_wgsl(e: &Expr, src: &str, expected: WgslType) -> String {
253    let raw = || -> &str { &src[e.span().start as usize..e.span().end as usize] };
254    match e {
255        Expr::Lit(Lit {
256            value: LitValue::Int(v),
257            ..
258        }) => {
259            if matches!(expected, WgslType::F32) {
260                format!("{v}.0")
261            } else {
262                v.to_string()
263            }
264        }
265        Expr::Lit(Lit {
266            value: LitValue::Float(v),
267            ..
268        }) => {
269            // Force `.0` suffix when stringification would otherwise omit
270            // the dot, so the literal stays a float in WGSL.
271            let s = format!("{v}");
272            if s.contains('.') || s.contains('e') || s.contains('E') {
273                s
274            } else {
275                format!("{s}.0")
276            }
277        }
278        Expr::Unary(u) if matches!(u.op, UnaryOp::Neg) => {
279            format!("-{}", emit_expr_as_wgsl(&u.operand, src, expected))
280        }
281        Expr::Binary(b) => {
282            let op = match b.op {
283                BinaryOp::Add => " + ",
284                BinaryOp::Sub => " - ",
285                BinaryOp::Mul => " * ",
286                BinaryOp::Div => " / ",
287                _ => return raw().to_string(),
288            };
289            format!(
290                "({}){}({})",
291                emit_expr_as_wgsl(&b.lhs, src, expected),
292                op,
293                emit_expr_as_wgsl(&b.rhs, src, expected)
294            )
295        }
296        Expr::Call(c) => {
297            // Constructor calls map to their WGSL spellings.
298            let head = match c.callee.as_str() {
299                "float2" | "vec2" => "vec2<f32>",
300                "float3" | "vec3" => "vec3<f32>",
301                "float4" | "vec4" => "vec4<f32>",
302                other => other,
303            };
304            let args: Vec<String> = c
305                .args
306                .iter()
307                .map(|a| emit_expr_as_wgsl(a, src, WgslType::F32))
308                .collect();
309            format!("{head}({})", args.join(", "))
310        }
311        _ => raw().to_string(),
312    }
313}