onedrop_hlsl/rewrite/array_lower.rs
1//! Pass: HLSL array → WGSL `array<T, N>` lowering (globals + locals).
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass 2: array global lowering
7// ---------------------------------------------------------------------------
8
/// Lower MD2-style global arrays (`const float4 samples[5] = {…};`) to a
/// WGSL `var` declaration with an explicit `array<T, N>` constructor, e.g.
/// `var samples: array<vec4<f32>, 5> = array<vec4<f32>, 5>(…);`.
///
/// The original HLSL line confuses `replace_types` (which doesn't know
/// about `[N]`) and `rewrite_local_declarations` (which doesn't accept
/// init lists). The lowered form is already WGSL-shaped, so the later
/// regex passes leave it alone.
///
/// Targets the `const float4 samples[5] = {…};` idiom (cluster K + H).
/// Only declarations accepted by [`collect_global_array_edits`] are
/// rewritten. If the source fails to parse, or nothing matches, the
/// source is returned unchanged.
#[cfg(test)]
pub(crate) fn lower_array_globals(src: &str) -> String {
    let Ok(tu) = parse_hlsl(src) else {
        return src.to_string();
    };
    let mut edits = Vec::new();
    collect_global_array_edits(&tu, src, &mut edits);
    apply_edits(src, &mut edits)
}
28}
29
30/// Walk every top-level `Item::GlobalVar` with an array length + InitList
31/// and emit one text edit per match. Shared between [`lower_array_globals`]
32/// (the standalone pass) and [`apply_lowerings`] (the combined pass used
33/// by [`apply_all`]).
34pub(crate) fn collect_global_array_edits(
35 tu: &TranslationUnit,
36 src: &str,
37 edits: &mut Vec<TextEdit>,
38) {
39 for item in &tu.items {
40 let Item::GlobalVar(g) = item else { continue };
41 if g.array_len.is_none() {
42 continue;
43 }
44 let Some(init) = &g.init else {
45 continue;
46 };
47 let Expr::InitList(list) = init else {
48 continue;
49 };
50 let comp_ty = type_from_typeref(&g.ty);
51 let comp_size = vec_size(comp_ty);
52 if comp_size == 0 {
53 continue;
54 }
55 let Some(array_len_expr) = &g.array_len else {
56 continue;
57 };
58 let Expr::Lit(Lit {
59 value: LitValue::Int(n),
60 ..
61 }) = array_len_expr
62 else {
63 continue;
64 };
65 let n = *n as usize;
66 let comp_wgsl = comp_ty.wgsl_name();
67 let mut elems: Vec<String> = Vec::new();
68 // The HLSL flat init `{ 0.0, 1.0, ... }` (component-by-component)
69 // needs grouping into vecN. We also accept the already-vecN form
70 // `{ float4(...), ... }`.
71 if list.elems.len() == n {
72 // Each element is one vector — emit as-is in WGSL constructor
73 // shape.
74 for e in &list.elems {
75 elems.push(emit_expr_as_wgsl(e, src, comp_ty));
76 }
77 } else if list.elems.len() == n * comp_size {
78 // Flat list: pack `comp_size` scalars per element.
79 for chunk in list.elems.chunks(comp_size) {
80 let parts: Vec<String> = chunk
81 .iter()
82 .map(|e| emit_expr_as_wgsl(e, src, WgslType::F32))
83 .collect();
84 elems.push(format!("{comp_wgsl}({})", parts.join(", ")));
85 }
86 } else {
87 continue; // shape mismatch, leave it alone
88 }
89 let body = elems.join(", ");
90 // Emit as WGSL `var`-initialised array. `const` would be cleaner
91 // but requires every element to be const-evaluable; `var` survives
92 // arithmetic like `11.0/3.0` and matches the address-space-free
93 // form that's legal at both module scope and inside `fs_main`
94 // (the existing pipeline drops the const-array body into the
95 // function body, so we can't assume module scope).
96 let replacement = format!(
97 "var {name}: array<{comp_wgsl}, {n}> = array<{comp_wgsl}, {n}>({body});",
98 name = g.name,
99 );
100 edits.push(TextEdit {
101 start: g.span.start,
102 end: span_includes_semi(src, g.span).end,
103 replacement,
104 });
105 }
106}
107
108/// Lower local array declarations (`float3 m[3];`, `float arr[5] = {…};`)
109/// to WGSL-valid `var` shape. The downstream regex `rewrite_local_declarations`
110/// doesn't understand the `[N]` suffix or init-lists, so the lowered form
111/// drops through to naga as-is. Targets the ougiel local-array failures
112/// and any preset with bare-typed local arrays.
113#[cfg(test)]
114pub(crate) fn lower_local_arrays(src: &str) -> String {
115 let Ok(tu) = parse_hlsl(src) else {
116 return src.to_string();
117 };
118 let mut edits = Vec::new();
119 if let Some(body) = &tu.shader_body {
120 collect_local_array_edits(body, src, &mut edits);
121 }
122 for item in &tu.items {
123 if let Item::Function(f) = item {
124 collect_local_array_edits(&f.body, src, &mut edits);
125 }
126 }
127 apply_edits(src, &mut edits)
128}
129
130pub(crate) fn collect_local_array_edits(block: &Block, src: &str, edits: &mut Vec<TextEdit>) {
131 for s in &block.stmts {
132 match s {
133 Stmt::LocalDecl(d) if d.array_len.is_some() => {
134 let Some(len_expr) = &d.array_len else {
135 continue;
136 };
137 let Expr::Lit(Lit {
138 value: LitValue::Int(n),
139 ..
140 }) = len_expr
141 else {
142 continue;
143 };
144 let comp_ty = type_from_typeref(&d.ty);
145 if comp_ty == WgslType::Unknown {
146 continue;
147 }
148 let comp_wgsl = comp_ty.wgsl_name();
149 let replacement = match &d.init {
150 None => format!("var {name}: array<{comp_wgsl}, {n}>;", name = d.name),
151 Some(Expr::InitList(l)) => {
152 // Scalar element types (`float arr[N]`) treat as
153 // 1 component per element so the flat-init branch
154 // works without a special case.
155 let comp_size = if comp_ty.is_scalar() {
156 1
157 } else {
158 vec_size(comp_ty)
159 };
160 if comp_size == 0 {
161 continue;
162 }
163 let n_usize = *n as usize;
164 let elems: Vec<String> = if comp_ty.is_scalar() && l.elems.len() == n_usize
165 {
166 l.elems
167 .iter()
168 .map(|e| emit_expr_as_wgsl(e, src, comp_ty))
169 .collect()
170 } else if l.elems.len() == n_usize {
171 // One vector per array element.
172 l.elems
173 .iter()
174 .map(|e| emit_expr_as_wgsl(e, src, comp_ty))
175 .collect()
176 } else if l.elems.len() == n_usize * comp_size {
177 l.elems
178 .chunks(comp_size)
179 .map(|chunk| {
180 let parts: Vec<String> = chunk
181 .iter()
182 .map(|e| emit_expr_as_wgsl(e, src, WgslType::F32))
183 .collect();
184 format!("{comp_wgsl}({})", parts.join(", "))
185 })
186 .collect()
187 } else {
188 continue;
189 };
190 format!(
191 "var {name}: array<{comp_wgsl}, {n}> = array<{comp_wgsl}, {n}>({});",
192 elems.join(", "),
193 name = d.name,
194 )
195 }
196 Some(_) => continue, // unusual init shape, leave alone
197 };
198 edits.push(TextEdit {
199 start: d.span.start,
200 end: span_includes_semi(src, d.span).end,
201 replacement,
202 });
203 }
204 Stmt::If(i) => {
205 if let Stmt::Block(b) = &*i.then_branch {
206 collect_local_array_edits(b, src, edits);
207 }
208 if let Some(eb) = &i.else_branch
209 && let Stmt::Block(b) = &**eb
210 {
211 collect_local_array_edits(b, src, edits);
212 }
213 }
214 Stmt::While(w) => {
215 if let Stmt::Block(b) = &*w.body {
216 collect_local_array_edits(b, src, edits);
217 }
218 }
219 Stmt::For(f) => {
220 if let Stmt::Block(b) = &*f.body {
221 collect_local_array_edits(b, src, edits);
222 }
223 }
224 Stmt::Block(b) => collect_local_array_edits(b, src, edits),
225 _ => {}
226 }
227 }
228}
229
230/// Extend a span forward to include the trailing `;` (if present). The
231/// AST's GlobalVar span ends at the name, not the semicolon — for textual
232/// replacement we want to swallow the whole declaration.
233fn span_includes_semi(src: &str, span: Span) -> Span {
234 let bytes = src.as_bytes();
235 let mut i = span.end as usize;
236 while i < bytes.len() && bytes[i] != b';' {
237 i += 1;
238 }
239 if i < bytes.len() {
240 i += 1; // include the `;`
241 }
242 Span {
243 end: i as u32,
244 ..span
245 }
246}
247
248/// Emit an HLSL expression as WGSL text — narrow subset, enough for the
249/// array-global lowering case where elements are literals or simple
250/// arithmetic. Falls back to the raw original text if the expression is
251/// outside the supported shape.
252fn emit_expr_as_wgsl(e: &Expr, src: &str, expected: WgslType) -> String {
253 let raw = || -> &str { &src[e.span().start as usize..e.span().end as usize] };
254 match e {
255 Expr::Lit(Lit {
256 value: LitValue::Int(v),
257 ..
258 }) => {
259 if matches!(expected, WgslType::F32) {
260 format!("{v}.0")
261 } else {
262 v.to_string()
263 }
264 }
265 Expr::Lit(Lit {
266 value: LitValue::Float(v),
267 ..
268 }) => {
269 // Force `.0` suffix when stringification would otherwise omit
270 // the dot, so the literal stays a float in WGSL.
271 let s = format!("{v}");
272 if s.contains('.') || s.contains('e') || s.contains('E') {
273 s
274 } else {
275 format!("{s}.0")
276 }
277 }
278 Expr::Unary(u) if matches!(u.op, UnaryOp::Neg) => {
279 format!("-{}", emit_expr_as_wgsl(&u.operand, src, expected))
280 }
281 Expr::Binary(b) => {
282 let op = match b.op {
283 BinaryOp::Add => " + ",
284 BinaryOp::Sub => " - ",
285 BinaryOp::Mul => " * ",
286 BinaryOp::Div => " / ",
287 _ => return raw().to_string(),
288 };
289 format!(
290 "({}){}({})",
291 emit_expr_as_wgsl(&b.lhs, src, expected),
292 op,
293 emit_expr_as_wgsl(&b.rhs, src, expected)
294 )
295 }
296 Expr::Call(c) => {
297 // Constructor calls map to their WGSL spellings.
298 let head = match c.callee.as_str() {
299 "float2" | "vec2" => "vec2<f32>",
300 "float3" | "vec3" => "vec3<f32>",
301 "float4" | "vec4" => "vec4<f32>",
302 other => other,
303 };
304 let args: Vec<String> = c
305 .args
306 .iter()
307 .map(|a| emit_expr_as_wgsl(a, src, WgslType::F32))
308 .collect();
309 format!("{head}({})", args.join(", "))
310 }
311 _ => raw().to_string(),
312 }
313}