onedrop_hlsl/rewrite/
chained_init.rs1use super::*;
30
31pub(crate) fn rewrite_chained_assign_inits(src: &str) -> String {
32 let Ok(tu) = parse_hlsl(src) else {
33 return src.to_string();
34 };
35 let mut edits = Vec::new();
36 if let Some(body) = &tu.shader_body {
37 walk_block(body, src, &mut edits);
38 }
39 for item in &tu.items {
40 if let Item::Function(f) = item {
41 walk_block(&f.body, src, &mut edits);
42 }
43 }
44 apply_edits(src, &mut edits)
45}
46
47fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
48 let stmts = &b.stmts;
49 let mut i = 0;
50 while i < stmts.len() {
51 if let Stmt::LocalDecl(d) = &stmts[i] {
52 let chain_len = count_preceding_synthetic(stmts, i, d.span.start);
53 if chain_len > 0 {
54 emit_chain_edit(d, &stmts[i - chain_len..i], src, edits);
55 }
56 }
57 descend(&stmts[i], src, edits);
58 i += 1;
59 }
60}
61
62fn descend(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
63 match s {
64 Stmt::If(i) => {
65 descend(&i.then_branch, src, edits);
66 if let Some(e) = &i.else_branch {
67 descend(e, src, edits);
68 }
69 }
70 Stmt::While(w) => descend(&w.body, src, edits),
71 Stmt::For(f) => descend(&f.body, src, edits),
72 Stmt::Block(b) => walk_block(b, src, edits),
73 _ => {}
74 }
75}
76
77fn count_preceding_synthetic(stmts: &[Stmt], i: usize, decl_start: u32) -> usize {
82 let mut k = 0;
83 while k < i {
84 let prev = &stmts[i - k - 1];
85 let Stmt::Assign(a) = prev else { break };
86 if a.span.start <= decl_start {
87 break;
88 }
89 k += 1;
90 }
91 k
92}
93
94fn emit_chain_edit(d: &LocalDecl, synth: &[Stmt], src: &str, edits: &mut Vec<TextEdit>) {
95 let last_value_end = match &synth[0] {
100 Stmt::Assign(a) => a.value.span().end,
101 _ => return,
102 };
103 let start = d.span.start;
104 if last_value_end <= start {
105 return; }
107
108 let mut text = String::new();
109 for stmt in synth {
111 let Stmt::Assign(a) = stmt else {
112 return;
113 };
114 text.push_str(slice(src, a.target.span()));
115 text.push_str(" = ");
116 text.push_str(slice(src, a.value.span()));
117 text.push_str("; ");
118 }
119 text.push_str(slice(src, d.span));
121 if let Some(init) = &d.init {
122 text.push_str(" = ");
123 text.push_str(slice(src, init.span()));
124 }
125
126 edits.push(TextEdit {
127 start,
128 end: last_value_end,
129 replacement: text,
130 });
131}
132
133fn slice(src: &str, sp: Span) -> &str {
134 &src[sp.start as usize..sp.end as usize]
135}
136
#[cfg(test)]
mod tests {
    //! End-to-end tests for `rewrite_chained_assign_inits`. Each test feeds a
    //! small `shader_body { ... }` snippet through the full parse/rewrite path.
    use super::*;

    /// A two-link chain: the inner assignment is hoisted in front of the
    /// declaration.
    #[test]
    fn two_deep_chain_lowered() {
        let src = "shader_body { float2 ruv = uv = 0.5; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(out.contains("uv = 0.5; float2 ruv = uv"), "got: {out}");
        // The original `;` after the chain must survive the edit, so the body
        // still ends with a terminated statement before the closing brace.
        assert!(out.trim_end_matches('}').trim_end().ends_with(';'));
    }

    /// A three-link chain: assignments are emitted innermost-first, so each
    /// variable is set before it is read.
    #[test]
    fn three_deep_chain_lowered() {
        let src = "shader_body { float2 ruv = uv = uv2 = 0.5; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(
            out.contains("uv2 = 0.5; uv = uv2; float2 ruv = uv"),
            "got: {out}"
        );
    }

    /// An assignment written by hand before a declaration starts at or before
    /// the declaration in the text, so it is not treated as synthetic and the
    /// source is returned unchanged.
    #[test]
    fn hand_written_assign_then_decl_untouched() {
        let src = "shader_body { uv = 0.5; float2 ruv = uv; }";
        let out = rewrite_chained_assign_inits(src);
        assert_eq!(out, src);
    }

    /// A chain embedded between unrelated statements. Only the chained
    /// declaration is rewritten; neighbouring statements are preserved.
    /// The `\` line continuations make `src` a single-line string.
    #[test]
    fn realistic_corpus_shape() {
        let src = "shader_body { float3 rc = GetBlur1(uv); \
                   float2 ruv = uv = 0.5 + (uv-0.5)*(1+(rc.y*0.05)); \
                   ret = tex2D(sampler_main, ruv).xyz; }";
        let out = rewrite_chained_assign_inits(src);
        assert!(
            out.contains("uv = 0.5 + (uv-0.5)*(1+(rc.y*0.05)); float2 ruv = uv"),
            "got: {out}"
        );
        // Statements before and after the chain are left untouched.
        assert!(out.contains("float3 rc = GetBlur1(uv);"));
        assert!(out.contains("ret = tex2D(sampler_main, ruv).xyz;"));
    }
}