1use super::*;
4
5pub(crate) fn rewrite_ternary_to_select(src: &str) -> String {
19 let Ok(tu) = parse_hlsl(src) else {
20 return src.to_string();
21 };
22 let mut edits = Vec::new();
23 if let Some(body) = &tu.shader_body {
24 scan_block_for_ternary(body, src, &mut edits);
25 }
26 for item in &tu.items {
27 match item {
28 Item::Function(f) => scan_block_for_ternary(&f.body, src, &mut edits),
29 Item::GlobalVar(g) => {
30 if let Some(init) = &g.init {
31 scan_expr_for_ternary(init, src, &mut edits);
32 }
33 }
34 _ => {}
35 }
36 }
37 apply_edits(src, &mut edits)
38}
39
40fn scan_block_for_ternary(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
41 for s in &b.stmts {
42 scan_stmt_for_ternary(s, src, edits);
43 }
44}
45
46fn scan_stmt_for_ternary(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
47 match s {
48 Stmt::LocalDecl(d) => {
49 if let Some(init) = &d.init {
50 scan_expr_for_ternary(init, src, edits);
51 }
52 if let Some(len) = &d.array_len {
53 scan_expr_for_ternary(len, src, edits);
54 }
55 }
56 Stmt::Assign(a) => {
57 scan_expr_for_ternary(&a.target, src, edits);
58 scan_expr_for_ternary(&a.value, src, edits);
59 }
60 Stmt::Expr(e) => scan_expr_for_ternary(e, src, edits),
61 Stmt::If(i) => {
62 scan_expr_for_ternary(&i.cond, src, edits);
63 scan_stmt_for_ternary(&i.then_branch, src, edits);
64 if let Some(b) = &i.else_branch {
65 scan_stmt_for_ternary(b, src, edits);
66 }
67 }
68 Stmt::While(w) => {
69 scan_expr_for_ternary(&w.cond, src, edits);
70 scan_stmt_for_ternary(&w.body, src, edits);
71 }
72 Stmt::For(f) => {
73 if let Some(init) = &f.init {
74 scan_stmt_for_ternary(init, src, edits);
75 }
76 if let Some(c) = &f.cond {
77 scan_expr_for_ternary(c, src, edits);
78 }
79 if let Some(st) = &f.step {
80 scan_expr_for_ternary(st, src, edits);
81 }
82 scan_stmt_for_ternary(&f.body, src, edits);
83 }
84 Stmt::Return(Some(e)) => scan_expr_for_ternary(e, src, edits),
85 Stmt::Block(b) => scan_block_for_ternary(b, src, edits),
86 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
87 }
88}
89
90fn scan_expr_for_ternary(e: &Expr, src: &str, edits: &mut Vec<TextEdit>) {
91 match e {
92 Expr::Ternary(t) => {
93 let replacement = emit_ternary_aware(e, src);
97 edits.push(TextEdit {
98 start: t.span.start,
99 end: t.span.end,
100 replacement,
101 });
102 }
103 Expr::Binary(b) => {
104 scan_expr_for_ternary(&b.lhs, src, edits);
105 scan_expr_for_ternary(&b.rhs, src, edits);
106 }
107 Expr::Unary(u) => scan_expr_for_ternary(&u.operand, src, edits),
108 Expr::Call(c) => {
109 for a in &c.args {
110 scan_expr_for_ternary(a, src, edits);
111 }
112 }
113 Expr::Member(m) => scan_expr_for_ternary(&m.base, src, edits),
114 Expr::Swizzle(s) => scan_expr_for_ternary(&s.base, src, edits),
115 Expr::Index(i) => {
116 scan_expr_for_ternary(&i.base, src, edits);
117 scan_expr_for_ternary(&i.index, src, edits);
118 }
119 Expr::InitList(l) => {
120 for e in &l.elems {
121 scan_expr_for_ternary(e, src, edits);
122 }
123 }
124 Expr::Assign(a) => {
125 scan_expr_for_ternary(&a.target, src, edits);
126 scan_expr_for_ternary(&a.value, src, edits);
127 }
128 Expr::Lit(_) | Expr::Ident(_, _) => {}
129 }
130}
131
132fn emit_ternary_aware(e: &Expr, src: &str) -> String {
139 if !subtree_has_ternary(e) {
140 return src[e.span().start as usize..e.span().end as usize].to_string();
141 }
142 match e {
143 Expr::Ternary(t) => format!(
144 "select(({}), ({}), ({}))",
145 emit_ternary_aware(&t.else_expr, src),
146 emit_ternary_aware(&t.then_expr, src),
147 emit_ternary_aware(&t.cond, src),
148 ),
149 Expr::Binary(b) => format!(
150 "({} {} {})",
151 emit_ternary_aware(&b.lhs, src),
152 binop_text(b.op),
153 emit_ternary_aware(&b.rhs, src),
154 ),
155 Expr::Unary(u) => format!("{}{}", unop_text(u.op), emit_ternary_aware(&u.operand, src)),
156 Expr::Call(c) => {
157 let args: Vec<String> = c.args.iter().map(|a| emit_ternary_aware(a, src)).collect();
158 format!("{}({})", c.callee, args.join(", "))
159 }
160 Expr::Member(m) => format!("{}.{}", emit_ternary_aware(&m.base, src), m.member),
161 Expr::Swizzle(s) => format!("{}.{}", emit_ternary_aware(&s.base, src), s.components),
162 Expr::Index(i) => format!(
163 "{}[{}]",
164 emit_ternary_aware(&i.base, src),
165 emit_ternary_aware(&i.index, src)
166 ),
167 Expr::Assign(a) => format!(
168 "{} {} {}",
169 emit_ternary_aware(&a.target, src),
170 assign_op_text(a.op),
171 emit_ternary_aware(&a.value, src)
172 ),
173 Expr::InitList(l) => {
177 let elems: Vec<String> = l.elems.iter().map(|e| emit_ternary_aware(e, src)).collect();
178 format!("{{ {} }}", elems.join(", "))
179 }
180 Expr::Lit(_) | Expr::Ident(_, _) => {
181 src[e.span().start as usize..e.span().end as usize].to_string()
182 }
183 }
184}
185
186fn subtree_has_ternary(e: &Expr) -> bool {
187 match e {
188 Expr::Ternary(_) => true,
189 Expr::Binary(b) => subtree_has_ternary(&b.lhs) || subtree_has_ternary(&b.rhs),
190 Expr::Unary(u) => subtree_has_ternary(&u.operand),
191 Expr::Call(c) => c.args.iter().any(subtree_has_ternary),
192 Expr::Member(m) => subtree_has_ternary(&m.base),
193 Expr::Swizzle(s) => subtree_has_ternary(&s.base),
194 Expr::Index(i) => subtree_has_ternary(&i.base) || subtree_has_ternary(&i.index),
195 Expr::InitList(l) => l.elems.iter().any(subtree_has_ternary),
196 Expr::Assign(a) => subtree_has_ternary(&a.target) || subtree_has_ternary(&a.value),
197 Expr::Lit(_) | Expr::Ident(_, _) => false,
198 }
199}
200
201fn binop_text(op: BinaryOp) -> &'static str {
202 match op {
203 BinaryOp::Add => "+",
204 BinaryOp::Sub => "-",
205 BinaryOp::Mul => "*",
206 BinaryOp::Div => "/",
207 BinaryOp::Rem => "%",
208 BinaryOp::Eq => "==",
209 BinaryOp::Ne => "!=",
210 BinaryOp::Lt => "<",
211 BinaryOp::Le => "<=",
212 BinaryOp::Gt => ">",
213 BinaryOp::Ge => ">=",
214 BinaryOp::And => "&&",
215 BinaryOp::Or => "||",
216 BinaryOp::BitAnd => "&",
217 BinaryOp::BitOr => "|",
218 BinaryOp::BitXor => "^",
219 BinaryOp::Shl => "<<",
220 BinaryOp::Shr => ">>",
221 }
222}
223
224fn unop_text(op: UnaryOp) -> &'static str {
225 match op {
226 UnaryOp::Neg => "-",
227 UnaryOp::Pos => "+",
228 UnaryOp::Not => "!",
229 UnaryOp::BitNot => "~",
230 }
231}
232
233fn assign_op_text(op: AssignOp) -> &'static str {
234 match op {
235 AssignOp::Set => "=",
236 AssignOp::Add => "+=",
237 AssignOp::Sub => "-=",
238 AssignOp::Mul => "*=",
239 AssignOp::Div => "/=",
240 AssignOp::Rem => "%=",
241 }
242}