onedrop_hlsl/rewrite/
brace_init.rs1use super::*;
34use crate::lex::Span;
35
36pub(crate) fn rewrite_brace_init(src: &str) -> String {
37 let Ok(tu) = parse_hlsl(src) else {
38 return src.to_string();
39 };
40 let mut edits = Vec::new();
41 if let Some(body) = &tu.shader_body {
42 walk_block(body, src, &mut edits);
43 }
44 for item in &tu.items {
45 if let Item::Function(f) = item {
46 walk_block(&f.body, src, &mut edits);
47 }
48 }
49 apply_edits(src, &mut edits)
50}
51
52fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
53 for s in &b.stmts {
54 walk_stmt(s, src, edits);
55 }
56}
57
58fn walk_stmt(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
59 match s {
60 Stmt::LocalDecl(d) => {
61 if d.array_len.is_some() {
65 return;
66 }
67 let Some(Expr::InitList(list)) = &d.init else {
68 return;
69 };
70 let ty = type_from_typeref(&d.ty);
71 let constructor = constructor_name(ty);
72 if constructor.is_empty() {
73 return;
74 }
75 let elem_count = expected_elem_count(ty);
76 if list.elems.len() != elem_count {
77 return;
80 }
81 let body = list
82 .elems
83 .iter()
84 .map(|e| ctx_slice(src, e.span()).to_string())
85 .collect::<Vec<_>>()
86 .join(", ");
87 edits.push(TextEdit {
88 start: list.span.start,
89 end: list.span.end,
90 replacement: format!("{constructor}({body})"),
91 });
92 }
93 Stmt::Block(b) => walk_block(b, src, edits),
94 Stmt::If(i) => {
95 walk_stmt(&i.then_branch, src, edits);
96 if let Some(e) = &i.else_branch {
97 walk_stmt(e, src, edits);
98 }
99 }
100 Stmt::While(w) => walk_stmt(&w.body, src, edits),
101 Stmt::For(f) => walk_stmt(&f.body, src, edits),
102 _ => {}
103 }
104}
105
106fn ctx_slice(src: &str, span: Span) -> &str {
107 &src[span.start as usize..span.end as usize]
108}
109
110fn constructor_name(ty: WgslType) -> &'static str {
111 match ty {
112 WgslType::Vec2F => "vec2<f32>",
113 WgslType::Vec3F => "vec3<f32>",
114 WgslType::Vec4F => "vec4<f32>",
115 WgslType::Mat2F => "mat2x2<f32>",
116 WgslType::Mat3F => "mat3x3<f32>",
117 WgslType::Mat4F => "mat4x4<f32>",
118 _ => "",
119 }
120}
121
122fn expected_elem_count(ty: WgslType) -> usize {
123 match ty {
124 WgslType::Vec2F => 2,
125 WgslType::Vec3F => 3,
126 WgslType::Vec4F => 4,
127 WgslType::Mat2F => 4,
128 WgslType::Mat3F => 9,
129 WgslType::Mat4F => 16,
130 _ => 0,
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::translate_shader;

    /// Asserts that `rewrite_brace_init` returns `src` without any change.
    fn assert_left_alone(src: &str) {
        let out = rewrite_brace_init(src);
        assert_eq!(out, src);
    }

    #[test]
    fn float2x2_brace_init_becomes_mat_constructor() {
        // Four scalars match a 2x2 matrix, so the braces become a constructor.
        let out = rewrite_brace_init("shader_body { float2x2 rot = { 1.0, 2.0, 3.0, 4.0 }; }");
        let expected = "float2x2 rot = mat2x2<f32>(1.0, 2.0, 3.0, 4.0);";
        assert!(out.contains(expected), "got: {out}");
    }

    #[test]
    fn float3_brace_init_becomes_vec_constructor() {
        // Three scalars match a three-component vector.
        let out = rewrite_brace_init("shader_body { float3 c = { 0.1, 0.2, 0.3 }; }");
        let expected = "float3 c = vec3<f32>(0.1, 0.2, 0.3);";
        assert!(out.contains(expected), "got: {out}");
    }

    #[test]
    fn translate_roundtrip_handles_rot_matrix() {
        let hlsl = r#"
shader_body {
    float q9 = 0.5;
    float2x2 rot = { cos(q9), sin(q9),
                     -sin(q9), cos(q9) };
    ret = float3(rot[0][0], rot[0][1], rot[1][0]);
}
"#;
        let wgsl = translate_shader(hlsl).expect("translates");

        // The declaration must come out in constructor form.
        assert!(
            wgsl.contains("var rot: mat2x2<f32> = mat2x2<f32>("),
            "expected `var rot: mat2x2<f32> = mat2x2<f32>(...)`, got:\n{wgsl}"
        );

        // From the first `rot` up to the next `;` no opening brace may remain.
        let rot_at = wgsl.find("rot").unwrap();
        let tail = &wgsl[rot_at..];
        let stmt_len = tail.find(';').unwrap();
        let decl = &tail[..stmt_len];
        assert!(!decl.contains('{'), "stale brace init: {wgsl}");
    }

    #[test]
    fn shape_mismatch_left_alone() {
        // Three elements do not fill a 2x2 matrix, so nothing is rewritten.
        assert_left_alone("shader_body { float2x2 rot = { 1.0, 2.0, 3.0 }; }");
    }

    #[test]
    fn nested_brace_inits_in_function_body() {
        // Declarations inside ordinary functions are rewritten too.
        let src = r#"
float helper(float t) {
    float2 v = { t, t * 2 };
    return v.x + v.y;
}
shader_body { ret = float3(helper(0.5), 0, 0); }
"#;
        let out = rewrite_brace_init(src);
        let expected = "float2 v = vec2<f32>(t, t * 2);";
        assert!(out.contains(expected), "got: {out}");
    }

    #[test]
    fn unknown_type_left_alone() {
        // A type without a known constructor is not rewritten.
        assert_left_alone("shader_body { MyStruct s = { 1, 2 }; }");
    }

    #[test]
    fn array_decl_left_alone() {
        // Array declarations keep their brace initialiser.
        assert_left_alone("shader_body { float arr[3] = { 1.0, 2.0, 3.0 }; }");
    }
}