onedrop_hlsl/rewrite/
brace_init.rs

1//! Pass: lower `T x = { … };` brace-init local declarations to
2//! `T x = T(…);` constructor calls.
3//!
4//! HLSL accepts brace initialisers on matrix and vector locals:
5//!
6//! ```hlsl
7//! float2x2 rot = { cos(q9), sin(q9), -sin(q9), cos(q9) };
8//! float3   c   = { 0.1, 0.2, 0.3 };
9//! ```
10//!
11//! WGSL has no brace-initialiser form for vec/mat types — both `var x:
12//! mat2x2<f32> = { … }` and the post-translation `mat2x2<f32> x = { … };`
13//! get rejected with `expected '(', found '<ident>'` because the parser
14//! is in declarator position when it hits the offending `{`.
15//!
16//! The downstream regex pipeline (`rewrite_local_declarations`,
17//! `LOCAL_DECL_REGEX`) explicitly excludes `{` from its capture group to
18//! avoid greedy-swallowing function bodies, so these decls fall through
19//! the regex untouched and naga rejects them at parse.
20//!
21//! This pass rewrites every applicable `LocalDecl` so the init is a
22//! constructor call. Targets `float2x2`/`float3x3`/`float4x4` (the
23//! dominant corpus shape — 8 warp presets on the 2 000-sample, all
24//! containing the `rot` rotation matrix pattern above) and the
25//! analogous `floatN` vector form when an explicit brace-init is used.
26//!
27//! Edits are minimal: replace just the `{ … }` span with `T(…)`. The
28//! enclosing `T x =` and the trailing `;` survive verbatim so the
29//! downstream local-decl regex sees a well-formed
30//! `T x = T(...);` shape and converts it to `var x: T = T(...);` on
31//! its next pass.
32
33use super::*;
34use crate::lex::Span;
35
36pub(crate) fn rewrite_brace_init(src: &str) -> String {
37    let Ok(tu) = parse_hlsl(src) else {
38        return src.to_string();
39    };
40    let mut edits = Vec::new();
41    if let Some(body) = &tu.shader_body {
42        walk_block(body, src, &mut edits);
43    }
44    for item in &tu.items {
45        if let Item::Function(f) = item {
46            walk_block(&f.body, src, &mut edits);
47        }
48    }
49    apply_edits(src, &mut edits)
50}
51
52fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
53    for s in &b.stmts {
54        walk_stmt(s, src, edits);
55    }
56}
57
58fn walk_stmt(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
59    match s {
60        Stmt::LocalDecl(d) => {
61            // Only lower when the init is a brace list and we recognise
62            // the type. `array_len.is_some()` is the array-decl path —
63            // `array_lower` handles those.
64            if d.array_len.is_some() {
65                return;
66            }
67            let Some(Expr::InitList(list)) = &d.init else {
68                return;
69            };
70            let ty = type_from_typeref(&d.ty);
71            let constructor = constructor_name(ty);
72            if constructor.is_empty() {
73                return;
74            }
75            let elem_count = expected_elem_count(ty);
76            if list.elems.len() != elem_count {
77                // Shape mismatch — leave alone, naga will surface a
78                // clearer error than our half-rewritten output.
79                return;
80            }
81            let body = list
82                .elems
83                .iter()
84                .map(|e| ctx_slice(src, e.span()).to_string())
85                .collect::<Vec<_>>()
86                .join(", ");
87            edits.push(TextEdit {
88                start: list.span.start,
89                end: list.span.end,
90                replacement: format!("{constructor}({body})"),
91            });
92        }
93        Stmt::Block(b) => walk_block(b, src, edits),
94        Stmt::If(i) => {
95            walk_stmt(&i.then_branch, src, edits);
96            if let Some(e) = &i.else_branch {
97                walk_stmt(e, src, edits);
98            }
99        }
100        Stmt::While(w) => walk_stmt(&w.body, src, edits),
101        Stmt::For(f) => walk_stmt(&f.body, src, edits),
102        _ => {}
103    }
104}
105
106fn ctx_slice(src: &str, span: Span) -> &str {
107    &src[span.start as usize..span.end as usize]
108}
109
110fn constructor_name(ty: WgslType) -> &'static str {
111    match ty {
112        WgslType::Vec2F => "vec2<f32>",
113        WgslType::Vec3F => "vec3<f32>",
114        WgslType::Vec4F => "vec4<f32>",
115        WgslType::Mat2F => "mat2x2<f32>",
116        WgslType::Mat3F => "mat3x3<f32>",
117        WgslType::Mat4F => "mat4x4<f32>",
118        _ => "",
119    }
120}
121
122fn expected_elem_count(ty: WgslType) -> usize {
123    match ty {
124        WgslType::Vec2F => 2,
125        WgslType::Vec3F => 3,
126        WgslType::Vec4F => 4,
127        WgslType::Mat2F => 4,
128        WgslType::Mat3F => 9,
129        WgslType::Mat4F => 16,
130        _ => 0,
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::translate_shader;

    /// A flat four-element list on a `float2x2` local turns into a
    /// `mat2x2<f32>(…)` call.
    #[test]
    fn float2x2_brace_init_becomes_mat_constructor() {
        let input = "shader_body { float2x2 rot = { 1.0, 2.0, 3.0, 4.0 }; }";
        // Only the brace block changes; the `T x =` prefix is left for the
        // local-decl regex that runs afterward.
        let expected = "float2x2 rot = mat2x2<f32>(1.0, 2.0, 3.0, 4.0);";
        let out = rewrite_brace_init(input);
        assert!(out.contains(expected), "got: {out}");
    }

    /// A three-element list on a `float3` local turns into a
    /// `vec3<f32>(…)` call.
    #[test]
    fn float3_brace_init_becomes_vec_constructor() {
        let input = "shader_body { float3 c = { 0.1, 0.2, 0.3 }; }";
        let expected = "float3 c = vec3<f32>(0.1, 0.2, 0.3);";
        let out = rewrite_brace_init(input);
        assert!(out.contains(expected), "got: {out}");
    }

    /// End-to-end: the rotation-matrix pattern goes through the full
    /// translation and comes out as a `var` with a constructor init.
    #[test]
    fn translate_roundtrip_handles_rot_matrix() {
        let hlsl = r#"
shader_body {
    float q9 = 0.5;
    float2x2 rot = { cos(q9), sin(q9),
                     -sin(q9), cos(q9) };
    ret = float3(rot[0][0], rot[0][1], rot[1][0]);
}
"#;
        let wgsl = translate_shader(hlsl).expect("translates");
        let has_ctor_decl = wgsl.contains("var rot: mat2x2<f32> = mat2x2<f32>(");
        assert!(
            has_ctor_decl,
            "expected `var rot: mat2x2<f32> = mat2x2<f32>(...)`, got:\n{wgsl}"
        );

        // Between the first `rot` and the following `;` no `{` may remain.
        let rot_at = wgsl.find("rot").unwrap();
        let tail = &wgsl[rot_at..];
        let semi_at = tail.find(';').unwrap();
        let decl_text = &tail[..semi_at];
        assert!(!decl_text.contains('{'), "stale brace init: {wgsl}");
    }

    /// Three elements fit neither the flat (4 scalars) nor the row-by-row
    /// (2 vectors) reading of a `float2x2`, so the source must come back
    /// verbatim and naga gets to report the authoring error.
    #[test]
    fn shape_mismatch_left_alone() {
        let input = "shader_body { float2x2 rot = { 1.0, 2.0, 3.0 }; }";
        assert_eq!(rewrite_brace_init(input), input);
    }

    /// Declarations inside helper function bodies are rewritten too.
    #[test]
    fn nested_brace_inits_in_function_body() {
        let input = r#"
float helper(float t) {
    float2 v = { t, t * 2 };
    return v.x + v.y;
}
shader_body { ret = float3(helper(0.5), 0, 0); }
"#;
        let expected = "float2 v = vec2<f32>(t, t * 2);";
        let out = rewrite_brace_init(input);
        assert!(out.contains(expected), "got: {out}");
    }

    /// `MyStruct` is not in the type table, so its brace init survives.
    #[test]
    fn unknown_type_left_alone() {
        let input = "shader_body { MyStruct s = { 1, 2 }; }";
        assert_eq!(rewrite_brace_init(input), input);
    }

    /// Array-typed locals belong to `array_lower`, not to this pass.
    #[test]
    fn array_decl_left_alone() {
        let input = "shader_body { float arr[3] = { 1.0, 2.0, 3.0 }; }";
        assert_eq!(rewrite_brace_init(input), input);
    }
}