1use super::*;
35
36pub(crate) fn rewrite_embedded_assigns(src: &str) -> String {
37 let Ok(tu) = parse_hlsl(src) else {
38 return src.to_string();
39 };
40 let mut edits = Vec::new();
41 if let Some(body) = &tu.shader_body {
42 walk_block(body, src, &mut edits);
43 }
44 for item in &tu.items {
45 if let Item::Function(f) = item {
46 walk_block(&f.body, src, &mut edits);
47 }
48 }
49 apply_edits(src, &mut edits)
50}
51
52fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
53 for s in &b.stmts {
54 walk_stmt(s, src, edits);
55 }
56}
57
58fn walk_stmt(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
59 let stmt_start = match stmt_span_start(s, src) {
64 Some(p) => p,
65 None => return, };
67 let mut assigns: Vec<&AssignExpr> = Vec::new();
68 match s {
69 Stmt::LocalDecl(d) => {
70 if let Some(init) = &d.init {
71 collect_embedded(init, &mut assigns);
72 }
73 if let Some(arr) = &d.array_len {
74 collect_embedded(arr, &mut assigns);
75 }
76 }
77 Stmt::Assign(a) => {
78 collect_embedded(&a.target, &mut assigns);
81 collect_embedded(&a.value, &mut assigns);
82 }
83 Stmt::Expr(e) => collect_embedded(e, &mut assigns),
84 Stmt::Return(Some(e)) => collect_embedded(e, &mut assigns),
85 Stmt::If(i) => {
86 collect_embedded(&i.cond, &mut assigns);
87 walk_stmt(&i.then_branch, src, edits);
88 if let Some(e) = &i.else_branch {
89 walk_stmt(e, src, edits);
90 }
91 }
92 Stmt::While(w) => {
93 collect_embedded(&w.cond, &mut assigns);
94 walk_stmt(&w.body, src, edits);
95 }
96 Stmt::For(f) => {
97 if let Some(cond) = &f.cond {
101 collect_embedded(cond, &mut assigns);
102 }
103 walk_stmt(&f.body, src, edits);
104 if let Some(init) = &f.init {
105 walk_stmt(init, src, edits);
106 }
107 }
108 Stmt::Block(b) => walk_block(b, src, edits),
109 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
110 }
111 if assigns.is_empty() {
112 return;
113 }
114 assigns.sort_by_key(|a| a.span.start);
116 let mut prelude = String::new();
117 for a in &assigns {
118 prelude.push_str(slice(src, a.target.span()));
119 prelude.push(' ');
120 prelude.push_str(assign_op_str(a.op));
121 prelude.push(' ');
122 prelude.push_str(slice(src, a.value.span()));
123 prelude.push_str("; ");
124 edits.push(TextEdit {
126 start: a.span.start,
127 end: a.span.end,
128 replacement: slice(src, a.target.span()).to_string(),
129 });
130 }
131 edits.push(TextEdit {
132 start: stmt_start,
133 end: stmt_start,
134 replacement: prelude,
135 });
136}
137
138fn assign_op_str(op: AssignOp) -> &'static str {
139 match op {
140 AssignOp::Set => "=",
141 AssignOp::Add => "+=",
142 AssignOp::Sub => "-=",
143 AssignOp::Mul => "*=",
144 AssignOp::Div => "/=",
145 AssignOp::Rem => "%=",
146 }
147}
148
149fn stmt_span_start(s: &Stmt, src: &str) -> Option<u32> {
150 match s {
151 Stmt::LocalDecl(d) => Some(d.span.start),
152 Stmt::Assign(a) => Some(a.span.start),
153 Stmt::Expr(e) => Some(e.span().start),
154 Stmt::Return(Some(e)) => find_return_keyword_before(src, e.span().start),
159 Stmt::Return(None) => None,
160 Stmt::If(i) => Some(i.span.start),
161 Stmt::While(w) => Some(w.span.start),
162 Stmt::For(f) => Some(f.span.start),
163 Stmt::Block(b) => Some(b.span.start),
164 Stmt::Break | Stmt::Continue => None,
165 }
166}
167
/// The `return` keyword as raw bytes, for scanning `src` backwards.
const RETURN_KW: &[u8] = b"return";

/// Locates the `return` keyword that introduces an expression starting at
/// byte offset `expr_start` and returns the keyword's byte offset.
///
/// Whitespace and opening parentheses between the keyword and the expression
/// are skipped. As a result `return x`, `return (x)` and `return(x)` all
/// resolve, even when the expression's span does not include its enclosing
/// parentheses.
///
/// Returns `None` in two cases:
/// - no standalone `return` precedes the expression (including when the
///   match is the tail of an identifier such as `noreturn`);
/// - `expr_start` lies past the end of `src`.
fn find_return_keyword_before(src: &str, expr_start: u32) -> Option<u32> {
    let bytes = src.as_bytes();
    let mut end = expr_start as usize;
    if end > bytes.len() {
        return None;
    }
    while end > 0 && (bytes[end - 1].is_ascii_whitespace() || bytes[end - 1] == b'(') {
        end -= 1;
    }
    let head = &bytes[..end];
    if !head.ends_with(RETURN_KW) {
        return None;
    }
    let kw_start = end - RETURN_KW.len();
    // Reject a match that is really the tail of a longer identifier.
    if matches!(head[..kw_start].last(), Some(&p) if p.is_ascii_alphanumeric() || p == b'_') {
        return None;
    }
    // kw_start < expr_start, which came from a u32, so this cannot truncate.
    Some(kw_start as u32)
}
196
197fn collect_embedded<'e>(e: &'e Expr, out: &mut Vec<&'e AssignExpr>) {
198 match e {
199 Expr::Assign(a) => {
200 out.push(a);
203 collect_embedded(&a.value, out);
204 }
205 Expr::Binary(b) => {
206 collect_embedded(&b.lhs, out);
207 collect_embedded(&b.rhs, out);
208 }
209 Expr::Unary(u) => collect_embedded(&u.operand, out),
210 Expr::Ternary(t) => {
211 collect_embedded(&t.cond, out);
212 collect_embedded(&t.then_expr, out);
213 collect_embedded(&t.else_expr, out);
214 }
215 Expr::Call(c) => {
216 for a in &c.args {
217 collect_embedded(a, out);
218 }
219 }
220 Expr::Member(m) => collect_embedded(&m.base, out),
221 Expr::Swizzle(s) => collect_embedded(&s.base, out),
222 Expr::Index(i) => {
223 collect_embedded(&i.base, out);
224 collect_embedded(&i.index, out);
225 }
226 Expr::InitList(l) => {
227 for e in &l.elems {
228 collect_embedded(e, out);
229 }
230 }
231 Expr::Lit(_) | Expr::Ident(_, _) => {}
232 }
233}
234
235fn slice(src: &str, sp: Span) -> &str {
236 &src[sp.start as usize..sp.end as usize]
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    // A single assignment used as a call argument is hoisted in front of the
    // statement and replaced in place by its target.
    #[test]
    fn lift_single_assign_in_call_arg() {
        let src = "shader_body { ret = lerp(a, tmp = b, c); }";
        let out = rewrite_embedded_assigns(src);
        assert!(
            out.contains("tmp = b; ret = lerp(a, tmp, c);"),
            "got: {out}"
        );
    }

    // Inside a function body, the hoisted assignment is placed before the
    // `return` keyword rather than between `return` and its expression.
    #[test]
    fn lift_inside_return() {
        let src = "float Get(float2 uvi) { float tmp; return lerp(a, tmp = b, c); } \
                   shader_body { ret = float3(Get(uv), 0); }";
        let out = rewrite_embedded_assigns(src);
        assert!(
            out.contains("tmp = b; return lerp(a, tmp, c);"),
            "got: {out}"
        );
    }

    // Input without embedded assignments must round-trip byte-for-byte.
    #[test]
    fn no_change_when_no_embedded_assigns() {
        let src = "shader_body { ret = lerp(a, b, c); }";
        let out = rewrite_embedded_assigns(src);
        assert_eq!(out, src);
    }

    // End-to-end over the on-disk fixture: the comma-terminated argument form
    // `tmp = GetBlur1(uvi),` must not appear in the translated output.
    #[test]
    fn full_whoah_warp_translates_without_embedded_assign_in_output() {
        let src = include_str!("../../tests/whoah_warp.hlsl");
        let wgsl = crate::translate_shader(src).expect("translate ok");
        let bad = "tmp = GetBlur1(uvi),";
        assert!(
            !wgsl.contains(bad),
            "embedded assign survived translate:\n{wgsl}"
        );
    }

    // The fixture must parse, and the rewrite must hoist the assignment
    // ahead of a `return lerp` (the helper the message calls Get1).
    #[test]
    fn full_whoah_warp_parses_and_rewrites() {
        let src = include_str!("../../tests/whoah_warp.hlsl");
        let parsed = crate::parse::parse_hlsl(src);
        assert!(parsed.is_ok(), "parse err: {:?}", parsed.err());
        let out = rewrite_embedded_assigns(src);
        assert!(
            out.contains("tmp = GetBlur1(uvi); return lerp"),
            "lift did not fire on Get1:\n{out}"
        );
    }

    // Inline copy of the corpus shape run through the full translator.
    // Passes when `tmp = GetBlur1(uvi)` is absent from the WGSL, or occurs at
    // least once as a `;`-terminated statement. It does not by itself rule out
    // an additional comma-form occurrence.
    #[test]
    fn corpus_shape_survives_full_translate_pipeline() {
        let src = "float3 Get1 (float2 uvi) {float3 tmp; float2 pix; \
                   return lerp (GetPixel(uvi), tmp = GetBlur1(uvi),change*4);} \
                   shader_body { ret = Get1(uv); }";
        let wgsl = crate::translate_shader(src).unwrap();
        assert!(
            !wgsl.contains("tmp = GetBlur1(uvi)") || wgsl.contains("tmp = GetBlur1(uvi);"),
            "raw assign-as-arg leaked into WGSL: {wgsl}"
        );
    }

    // Same shape as above, checking the exact rewritten text. The original
    // spacing (`lerp (`, `tmp,change*4`) is preserved around the replacement.
    #[test]
    fn exact_corpus_shape_whoah_get1() {
        let src = "float3 Get1 (float2 uvi) {float3 tmp; float2 pix; \
                   return lerp (GetPixel(uvi), tmp = GetBlur1(uvi),change*4);} \
                   shader_body { ret = Get1(uv); }";
        let out = rewrite_embedded_assigns(src);
        assert!(
            out.contains("tmp = GetBlur1(uvi); return lerp (GetPixel(uvi), tmp,change*4);"),
            "got: {out}"
        );
    }

    // A plain assignment statement is not "embedded" and must be left alone.
    #[test]
    fn top_level_assign_not_lifted() {
        let src = "shader_body { tmp = b; ret = tmp; }";
        let out = rewrite_embedded_assigns(src);
        assert_eq!(out, src);
    }
}