onedrop_hlsl/rewrite/
qa_stub.rs1use super::*;
4
5pub(crate) fn inject_qa_stub(src: &str) -> String {
31 let Ok(tu) = parse_hlsl(src) else {
32 return src.to_string();
33 };
34 let mut uses = QaUses::default();
35 if let Some(body) = &tu.shader_body {
36 scan_block_for_qa(body, &mut uses);
37 }
38 for item in &tu.items {
39 if let Item::Function(f) = item {
40 scan_block_for_qa(&f.body, &mut uses);
41 }
42 }
43 if uses.referenced_mask == 0 {
44 return src.to_string();
45 }
46 let Some(body) = &tu.shader_body else {
47 return src.to_string();
48 };
49
50 let mut edits: Vec<TextEdit> = uses
51 .mat_call_spans
52 .iter()
53 .map(|(sp, letter)| TextEdit {
54 start: sp.start,
55 end: sp.end,
56 replacement: format!(
57 "float2x2(_q{letter}.x, _q{letter}.y, _q{letter}.z, _q{letter}.w)"
58 ),
59 })
60 .collect();
61
62 if let Some(open) = find_brace_after(src, body.span.start) {
63 let mut stub = String::new();
64 for (idx, letter) in QA_LETTERS.iter().enumerate() {
65 if uses.referenced_mask & (1u8 << idx) != 0 {
66 stub.push_str(&format!(" float4 _q{letter} = float4(0.0, 0.0, 0.0, 0.0);"));
67 }
68 }
69 edits.push(TextEdit {
70 start: open,
71 end: open,
72 replacement: stub,
73 });
74 }
75 apply_edits(src, &mut edits)
76}
77
/// Suffix letters of the recognised `_q<letter>` identifiers, in bit order:
/// the letter at index `i` corresponds to bit `i` of `QaUses::referenced_mask`.
const QA_LETTERS: [char; 8] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'];
80
/// Maps an identifier of the exact form `_q<letter>` (letter in `a..=h`) to
/// the zero-based index of that letter.
///
/// Returns `None` for any other name, including names of a different length,
/// upper-case variants, and letters past `h`.
fn qa_letter_index(name: &str) -> Option<usize> {
    match name.as_bytes() {
        // Exactly three bytes: '_', 'q', then a letter in the accepted range.
        [b'_', b'q', letter @ b'a'..=b'h'] => Some(usize::from(letter - b'a')),
        _ => None,
    }
}
94
95#[derive(Default)]
96struct QaUses {
97 referenced_mask: u8,
99 mat_call_spans: Vec<(Span, char)>,
102}
103
104fn scan_block_for_qa(b: &Block, out: &mut QaUses) {
105 for s in &b.stmts {
106 scan_stmt_for_qa(s, out);
107 }
108}
109
110fn scan_stmt_for_qa(s: &Stmt, out: &mut QaUses) {
111 match s {
112 Stmt::LocalDecl(d) => {
113 if let Some(init) = &d.init {
114 scan_expr_for_qa(init, out);
115 }
116 }
117 Stmt::Assign(a) => {
118 scan_expr_for_qa(&a.target, out);
119 scan_expr_for_qa(&a.value, out);
120 }
121 Stmt::Expr(e) => scan_expr_for_qa(e, out),
122 Stmt::If(i) => {
123 scan_expr_for_qa(&i.cond, out);
124 scan_stmt_for_qa(&i.then_branch, out);
125 if let Some(e) = &i.else_branch {
126 scan_stmt_for_qa(e, out);
127 }
128 }
129 Stmt::While(w) => {
130 scan_expr_for_qa(&w.cond, out);
131 scan_stmt_for_qa(&w.body, out);
132 }
133 Stmt::For(f) => {
134 if let Some(init) = &f.init {
135 scan_stmt_for_qa(init, out);
136 }
137 if let Some(c) = &f.cond {
138 scan_expr_for_qa(c, out);
139 }
140 if let Some(st) = &f.step {
141 scan_expr_for_qa(st, out);
142 }
143 scan_stmt_for_qa(&f.body, out);
144 }
145 Stmt::Return(Some(e)) => scan_expr_for_qa(e, out),
146 Stmt::Block(b) => scan_block_for_qa(b, out),
147 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
148 }
149}
150
151fn scan_expr_for_qa(e: &Expr, out: &mut QaUses) {
152 match e {
153 Expr::Ident(name, _) => {
154 if let Some(idx) = qa_letter_index(name) {
155 out.referenced_mask |= 1u8 << idx;
156 }
157 }
158 Expr::Lit(_) => {}
159 Expr::Binary(b) => {
160 scan_expr_for_qa(&b.lhs, out);
161 scan_expr_for_qa(&b.rhs, out);
162 }
163 Expr::Unary(u) => scan_expr_for_qa(&u.operand, out),
164 Expr::Ternary(t) => {
165 scan_expr_for_qa(&t.cond, out);
166 scan_expr_for_qa(&t.then_expr, out);
167 scan_expr_for_qa(&t.else_expr, out);
168 }
169 Expr::Call(c) => {
170 if c.callee == "float2x2"
171 && c.args.len() == 1
172 && let Expr::Ident(n, _) = &c.args[0]
173 && let Some(idx) = qa_letter_index(n)
174 {
175 let letter = QA_LETTERS[idx];
176 out.mat_call_spans.push((c.span, letter));
177 out.referenced_mask |= 1u8 << idx;
178 return;
179 }
180 for a in &c.args {
181 scan_expr_for_qa(a, out);
182 }
183 }
184 Expr::Member(m) => scan_expr_for_qa(&m.base, out),
185 Expr::Swizzle(s) => scan_expr_for_qa(&s.base, out),
186 Expr::Index(i) => {
187 scan_expr_for_qa(&i.base, out);
188 scan_expr_for_qa(&i.index, out);
189 }
190 Expr::InitList(l) => {
191 for e in &l.elems {
192 scan_expr_for_qa(e, out);
193 }
194 }
195 Expr::Assign(a) => {
196 scan_expr_for_qa(&a.target, out);
197 scan_expr_for_qa(&a.value, out);
198 }
199 }
200}
201
/// Returns the byte offset just past the first `{` found at or after byte
/// offset `start` in `src`.
///
/// Yields `None` when no `{` occurs from `start` onwards, which includes the
/// case where `start` lies at or beyond the end of `src`. The search is a raw
/// byte scan: braces inside comments or string literals are not skipped.
fn find_brace_after(src: &str, start: u32) -> Option<u32> {
    let from = start as usize;
    let tail = src.as_bytes().get(from..)?;
    let rel = tail.iter().position(|&byte| byte == b'{')?;
    // Offsets are expressed as `u32`, the same width as the `start` argument.
    Some((from + rel + 1) as u32)
}