onedrop_hlsl/rewrite/
ternary.rs

1//! Pass: ternary `cond ? a : b` → WGSL `select(b, a, cond)`.
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass: ternary `?:` → WGSL `select(...)`
7// ---------------------------------------------------------------------------
8
9/// Lower every `cond ? a : b` to `select(b, a, cond)`. WGSL has no
10/// ternary operator and the regex translator doesn't recognise the
11/// `?`/`:` tokens. Many presets hit `expected ';'; found '?'` before this
12/// pass.
13///
14/// Handles nested ternaries by emitting one edit per *top-level* ternary
15/// (in source order) and rendering inner ternaries inside the replacement
16/// string via a recursive expression emitter. That keeps text edits
17/// non-overlapping, which the [`apply_edits`] machinery requires.
18pub(crate) fn rewrite_ternary_to_select(src: &str) -> String {
19    let Ok(tu) = parse_hlsl(src) else {
20        return src.to_string();
21    };
22    let mut edits = Vec::new();
23    if let Some(body) = &tu.shader_body {
24        scan_block_for_ternary(body, src, &mut edits);
25    }
26    for item in &tu.items {
27        match item {
28            Item::Function(f) => scan_block_for_ternary(&f.body, src, &mut edits),
29            Item::GlobalVar(g) => {
30                if let Some(init) = &g.init {
31                    scan_expr_for_ternary(init, src, &mut edits);
32                }
33            }
34            _ => {}
35        }
36    }
37    apply_edits(src, &mut edits)
38}
39
40fn scan_block_for_ternary(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
41    for s in &b.stmts {
42        scan_stmt_for_ternary(s, src, edits);
43    }
44}
45
46fn scan_stmt_for_ternary(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
47    match s {
48        Stmt::LocalDecl(d) => {
49            if let Some(init) = &d.init {
50                scan_expr_for_ternary(init, src, edits);
51            }
52            if let Some(len) = &d.array_len {
53                scan_expr_for_ternary(len, src, edits);
54            }
55        }
56        Stmt::Assign(a) => {
57            scan_expr_for_ternary(&a.target, src, edits);
58            scan_expr_for_ternary(&a.value, src, edits);
59        }
60        Stmt::Expr(e) => scan_expr_for_ternary(e, src, edits),
61        Stmt::If(i) => {
62            scan_expr_for_ternary(&i.cond, src, edits);
63            scan_stmt_for_ternary(&i.then_branch, src, edits);
64            if let Some(b) = &i.else_branch {
65                scan_stmt_for_ternary(b, src, edits);
66            }
67        }
68        Stmt::While(w) => {
69            scan_expr_for_ternary(&w.cond, src, edits);
70            scan_stmt_for_ternary(&w.body, src, edits);
71        }
72        Stmt::For(f) => {
73            if let Some(init) = &f.init {
74                scan_stmt_for_ternary(init, src, edits);
75            }
76            if let Some(c) = &f.cond {
77                scan_expr_for_ternary(c, src, edits);
78            }
79            if let Some(st) = &f.step {
80                scan_expr_for_ternary(st, src, edits);
81            }
82            scan_stmt_for_ternary(&f.body, src, edits);
83        }
84        Stmt::Return(Some(e)) => scan_expr_for_ternary(e, src, edits),
85        Stmt::Block(b) => scan_block_for_ternary(b, src, edits),
86        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
87    }
88}
89
90fn scan_expr_for_ternary(e: &Expr, src: &str, edits: &mut Vec<TextEdit>) {
91    match e {
92        Expr::Ternary(t) => {
93            // One edit for the whole ternary span. Recursive emit handles
94            // nested ternaries, so we never descend into the children
95            // here (doing so would double-emit and overlap).
96            let replacement = emit_ternary_aware(e, src);
97            edits.push(TextEdit {
98                start: t.span.start,
99                end: t.span.end,
100                replacement,
101            });
102        }
103        Expr::Binary(b) => {
104            scan_expr_for_ternary(&b.lhs, src, edits);
105            scan_expr_for_ternary(&b.rhs, src, edits);
106        }
107        Expr::Unary(u) => scan_expr_for_ternary(&u.operand, src, edits),
108        Expr::Call(c) => {
109            for a in &c.args {
110                scan_expr_for_ternary(a, src, edits);
111            }
112        }
113        Expr::Member(m) => scan_expr_for_ternary(&m.base, src, edits),
114        Expr::Swizzle(s) => scan_expr_for_ternary(&s.base, src, edits),
115        Expr::Index(i) => {
116            scan_expr_for_ternary(&i.base, src, edits);
117            scan_expr_for_ternary(&i.index, src, edits);
118        }
119        Expr::InitList(l) => {
120            for e in &l.elems {
121                scan_expr_for_ternary(e, src, edits);
122            }
123        }
124        Expr::Assign(a) => {
125            scan_expr_for_ternary(&a.target, src, edits);
126            scan_expr_for_ternary(&a.value, src, edits);
127        }
128        Expr::Lit(_) | Expr::Ident(_, _) => {}
129    }
130}
131
132/// Render an expression to text, replacing any ternaries in its subtree
133/// with `select(...)`. Non-ternary nodes that contain ternaries get
134/// rebuilt from the AST so the ternary substitution can flow up; pure
135/// non-ternary subtrees fall back to the original source slice (which
136/// preserves whitespace, comments and HLSL idioms that downstream regex
137/// passes still need to see verbatim).
138fn emit_ternary_aware(e: &Expr, src: &str) -> String {
139    if !subtree_has_ternary(e) {
140        return src[e.span().start as usize..e.span().end as usize].to_string();
141    }
142    match e {
143        Expr::Ternary(t) => format!(
144            "select(({}), ({}), ({}))",
145            emit_ternary_aware(&t.else_expr, src),
146            emit_ternary_aware(&t.then_expr, src),
147            emit_ternary_aware(&t.cond, src),
148        ),
149        Expr::Binary(b) => format!(
150            "({} {} {})",
151            emit_ternary_aware(&b.lhs, src),
152            binop_text(b.op),
153            emit_ternary_aware(&b.rhs, src),
154        ),
155        Expr::Unary(u) => format!("{}{}", unop_text(u.op), emit_ternary_aware(&u.operand, src)),
156        Expr::Call(c) => {
157            let args: Vec<String> = c.args.iter().map(|a| emit_ternary_aware(a, src)).collect();
158            format!("{}({})", c.callee, args.join(", "))
159        }
160        Expr::Member(m) => format!("{}.{}", emit_ternary_aware(&m.base, src), m.member),
161        Expr::Swizzle(s) => format!("{}.{}", emit_ternary_aware(&s.base, src), s.components),
162        Expr::Index(i) => format!(
163            "{}[{}]",
164            emit_ternary_aware(&i.base, src),
165            emit_ternary_aware(&i.index, src)
166        ),
167        Expr::Assign(a) => format!(
168            "{} {} {}",
169            emit_ternary_aware(&a.target, src),
170            assign_op_text(a.op),
171            emit_ternary_aware(&a.value, src)
172        ),
173        // Lit/Ident/InitList don't (or rarely) contain ternaries; the
174        // `subtree_has_ternary` short-circuit above catches the common
175        // cases, but InitList could still hit here for nested inits.
176        Expr::InitList(l) => {
177            let elems: Vec<String> = l.elems.iter().map(|e| emit_ternary_aware(e, src)).collect();
178            format!("{{ {} }}", elems.join(", "))
179        }
180        Expr::Lit(_) | Expr::Ident(_, _) => {
181            src[e.span().start as usize..e.span().end as usize].to_string()
182        }
183    }
184}
185
186fn subtree_has_ternary(e: &Expr) -> bool {
187    match e {
188        Expr::Ternary(_) => true,
189        Expr::Binary(b) => subtree_has_ternary(&b.lhs) || subtree_has_ternary(&b.rhs),
190        Expr::Unary(u) => subtree_has_ternary(&u.operand),
191        Expr::Call(c) => c.args.iter().any(subtree_has_ternary),
192        Expr::Member(m) => subtree_has_ternary(&m.base),
193        Expr::Swizzle(s) => subtree_has_ternary(&s.base),
194        Expr::Index(i) => subtree_has_ternary(&i.base) || subtree_has_ternary(&i.index),
195        Expr::InitList(l) => l.elems.iter().any(subtree_has_ternary),
196        Expr::Assign(a) => subtree_has_ternary(&a.target) || subtree_has_ternary(&a.value),
197        Expr::Lit(_) | Expr::Ident(_, _) => false,
198    }
199}
200
201fn binop_text(op: BinaryOp) -> &'static str {
202    match op {
203        BinaryOp::Add => "+",
204        BinaryOp::Sub => "-",
205        BinaryOp::Mul => "*",
206        BinaryOp::Div => "/",
207        BinaryOp::Rem => "%",
208        BinaryOp::Eq => "==",
209        BinaryOp::Ne => "!=",
210        BinaryOp::Lt => "<",
211        BinaryOp::Le => "<=",
212        BinaryOp::Gt => ">",
213        BinaryOp::Ge => ">=",
214        BinaryOp::And => "&&",
215        BinaryOp::Or => "||",
216        BinaryOp::BitAnd => "&",
217        BinaryOp::BitOr => "|",
218        BinaryOp::BitXor => "^",
219        BinaryOp::Shl => "<<",
220        BinaryOp::Shr => ">>",
221    }
222}
223
224fn unop_text(op: UnaryOp) -> &'static str {
225    match op {
226        UnaryOp::Neg => "-",
227        UnaryOp::Pos => "+",
228        UnaryOp::Not => "!",
229        UnaryOp::BitNot => "~",
230    }
231}
232
233fn assign_op_text(op: AssignOp) -> &'static str {
234    match op {
235        AssignOp::Set => "=",
236        AssignOp::Add => "+=",
237        AssignOp::Sub => "-=",
238        AssignOp::Mul => "*=",
239        AssignOp::Div => "/=",
240        AssignOp::Rem => "%=",
241    }
242}