onedrop_hlsl/rewrite/
vec_cmp.rs

1//! Pass: broadcast scalar operand of vec/scalar comparisons.
2
3use super::bool_arith::decl_target_float_kind;
4use super::*;
5
6// ---------------------------------------------------------------------------
7// Pass: broadcast scalar operand of vec/scalar comparisons
8// ---------------------------------------------------------------------------
9
10/// HLSL auto-broadcasts the scalar operand of a comparison
11/// (`vec3 >= 0.1` → element-wise `vec3 >= float3(0.1)`); WGSL rejects
12/// `vec3<f32> >= f32` with `InvalidBinaryOperandTypes` (a dominant
13/// validate-stage cluster, op=Ge/Le/Gt/Lt/Eq/Ne with
14/// `lhs_type: Vector`). Wraps the scalar operand in a `vecN<f32>(…)`
15/// constructor so the comparison evaluates as a vector-bool.
16///
17/// Stays conservative: only fires when one side is a *known* vector
18/// and the other is a *known* scalar f32. `(a < b) <= c` chained
19/// comparisons (where one side is already bool) are left alone — they
20/// belong to the eval long-tail, not this cluster.
21pub(crate) fn rewrite_vec_scalar_compare(src: &str) -> String {
22    let Ok(tu) = parse_hlsl(src) else {
23        return src.to_string();
24    };
25    let mut ctx = WalkCtx::new(src);
26    // Seed globals with the downstream-widened type (int → f32) rather
27    // than the literal HLSL type, mirroring `walk_stmt_for_vsc`'s
28    // body-decl path. `static const int anz = 35;` parses as a global, and
29    // after `replace_types` widens to `f32`; without this, the body's
30    // `n <= anz` lookup would see anz as I32 and the cast wrap wouldn't
31    // fire.
32    for item in &tu.items {
33        if let Item::GlobalVar(g) = item {
34            let ty = if g.array_len.is_some() {
35                WgslType::Unknown
36            } else {
37                decl_target_float_kind(&g.ty)
38            };
39            ctx.declare(&g.name, ty);
40        }
41    }
42    if let Some(body) = &tu.shader_body {
43        walk_block_for_vsc(body, &mut ctx);
44    }
45    for item in &tu.items {
46        if let Item::Function(f) = item {
47            ctx.scope_push();
48            for p in &f.params {
49                ctx.declare(&p.name, decl_target_float_kind(&p.ty));
50            }
51            walk_block_for_vsc(&f.body, &mut ctx);
52            ctx.scope_pop();
53        }
54    }
55    apply_edits(src, &mut ctx.edits)
56}
57
58fn walk_block_for_vsc(b: &Block, ctx: &mut WalkCtx) {
59    ctx.scope_push();
60    for s in &b.stmts {
61        walk_stmt_for_vsc(s, ctx);
62    }
63    ctx.scope_pop();
64}
65
66fn walk_stmt_for_vsc(s: &Stmt, ctx: &mut WalkCtx) {
67    match s {
68        Stmt::LocalDecl(d) => {
69            // Most `int x = …;` decls become `var x: f32 = …;` after
70            // `replace_types` widens int → f32, so the broadcast decision
71            // must reflect the *downstream* type. The exception is
72            // `for (int n = …; …)` loop counters, which
73            // `collect_for_int_edits` keeps as `i32`. The `For` arm below
74            // handles those specially.
75            ctx.declare(&d.name, decl_target_float_kind(&d.ty));
76            if let Some(init) = &d.init {
77                walk_expr_for_vsc(init, ctx);
78            }
79        }
80        Stmt::Assign(a) => {
81            walk_expr_for_vsc(&a.target, ctx);
82            walk_expr_for_vsc(&a.value, ctx);
83        }
84        Stmt::Expr(e) => walk_expr_for_vsc(e, ctx),
85        Stmt::If(i) => {
86            walk_expr_for_vsc(&i.cond, ctx);
87            walk_stmt_for_vsc(&i.then_branch, ctx);
88            if let Some(e) = &i.else_branch {
89                walk_stmt_for_vsc(e, ctx);
90            }
91        }
92        Stmt::While(w) => {
93            walk_expr_for_vsc(&w.cond, ctx);
94            walk_stmt_for_vsc(&w.body, ctx);
95        }
96        Stmt::For(f) => {
97            ctx.scope_push();
98            if let Some(init) = &f.init {
99                // For-loop counters keep their declared HLSL type (i32 for
100                // `int`, f32 for `float`), unlike the body-decl `int → f32`
101                // widening. Declare with `type_from_typeref` so a `for (int n
102                // = 1; n <= anz; …)` comparison sees lhs=I32 against
103                // rhs=F32 and the cast wrap fires.
104                if let Stmt::LocalDecl(d) = init.as_ref() {
105                    ctx.declare(&d.name, type_from_typeref(&d.ty));
106                    if let Some(e) = &d.init {
107                        walk_expr_for_vsc(e, ctx);
108                    }
109                } else {
110                    walk_stmt_for_vsc(init, ctx);
111                }
112            }
113            if let Some(c) = &f.cond {
114                walk_expr_for_vsc(c, ctx);
115            }
116            if let Some(st) = &f.step {
117                walk_expr_for_vsc(st, ctx);
118            }
119            walk_stmt_for_vsc(&f.body, ctx);
120            ctx.scope_pop();
121        }
122        Stmt::Return(Some(e)) => walk_expr_for_vsc(e, ctx),
123        Stmt::Block(b) => walk_block_for_vsc(b, ctx),
124        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
125    }
126}
127
128fn walk_expr_for_vsc(e: &Expr, ctx: &mut WalkCtx) {
129    match e {
130        Expr::Binary(b) => {
131            walk_expr_for_vsc(&b.lhs, ctx);
132            walk_expr_for_vsc(&b.rhs, ctx);
133            if is_compare_op(b.op) {
134                try_broadcast_scalar_compare(b, ctx);
135            } else if is_arith_op(b.op) {
136                try_promote_i32_arith(b, ctx);
137            }
138        }
139        Expr::Unary(u) => walk_expr_for_vsc(&u.operand, ctx),
140        Expr::Ternary(t) => {
141            walk_expr_for_vsc(&t.cond, ctx);
142            walk_expr_for_vsc(&t.then_expr, ctx);
143            walk_expr_for_vsc(&t.else_expr, ctx);
144        }
145        Expr::Call(c) => {
146            for a in &c.args {
147                walk_expr_for_vsc(a, ctx);
148            }
149        }
150        Expr::Member(m) => walk_expr_for_vsc(&m.base, ctx),
151        Expr::Swizzle(s) => walk_expr_for_vsc(&s.base, ctx),
152        Expr::Index(i) => {
153            walk_expr_for_vsc(&i.base, ctx);
154            walk_expr_for_vsc(&i.index, ctx);
155        }
156        Expr::InitList(l) => {
157            for e in &l.elems {
158                walk_expr_for_vsc(e, ctx);
159            }
160        }
161        Expr::Assign(a) => {
162            walk_expr_for_vsc(&a.target, ctx);
163            walk_expr_for_vsc(&a.value, ctx);
164        }
165        Expr::Lit(_) | Expr::Ident(_, _) => {}
166    }
167}
168
169fn is_compare_op(op: BinaryOp) -> bool {
170    matches!(
171        op,
172        BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge
173    )
174}
175
176fn is_arith_op(op: BinaryOp) -> bool {
177    matches!(
178        op,
179        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Rem
180    )
181}
182
183/// Promote an `i32` operand of an `f32 ⟦arith⟧ i32` binop to `f32` by
184/// wrapping it in `float(…)`. WGSL refuses mixed `f32 * i32`
185/// arithmetic; corpus shape (`martin - pixies party 2 random
186/// mosscity.milk`):
187///
188/// ```text
189/// for (int n=1; n<=anz; n++) {
190///     z = 1 - fract(1.0*n/anz - 1.0*fract(-t_rel/anz));
191/// }
192/// ```
193///
194/// After `for_init`, `n: i32`. The condition's `n <= anz` already gets
195/// a `f32(n)` cast from [`try_broadcast_scalar_compare`]; the body's
196/// `1.0*n` needs the same treatment.
197fn try_promote_i32_arith(b: &crate::ast::BinaryExpr, ctx: &mut WalkCtx) {
198    let lt = infer_type_static(&b.lhs, ctx);
199    let rt = infer_type_static(&b.rhs, ctx);
200    let int_side: &Expr = match (lt, rt) {
201        (WgslType::I32, WgslType::F32) => &b.lhs,
202        (WgslType::F32, WgslType::I32) => &b.rhs,
203        // Both i32: leave alone (pure-int arithmetic is valid).
204        // Vector mixes are handled by the comparison path / binop_vec.
205        _ => return,
206    };
207    // Skip if already a `float(…)` / `f32(…)` cast — keeps the pass
208    // idempotent.
209    if let Expr::Call(c) = int_side
210        && matches!(c.callee.as_str(), "float" | "f32")
211    {
212        return;
213    }
214    let span = int_side.span();
215    ctx.edits.push(TextEdit {
216        start: span.start,
217        end: span.start,
218        replacement: "float(".to_string(),
219    });
220    ctx.edits.push(TextEdit {
221        start: span.end,
222        end: span.end,
223        replacement: ")".to_string(),
224    });
225}
226
227/// Inspect a comparison binop's operands; if one is a known vector and
228/// the other is a known scalar f32, emit a constructor wrap around the
229/// scalar so the comparison evaluates as a vector-bool. Zero-length
230/// insertions only, so the wrap survives even when other passes have
231/// also edited the operands.
232fn try_broadcast_scalar_compare(b: &crate::ast::BinaryExpr, ctx: &mut WalkCtx) {
233    let lt = infer_type_static(&b.lhs, ctx);
234    let rt = infer_type_static(&b.rhs, ctx);
235    enum Wrap<'a> {
236        ToVec(WgslType, &'a Expr),
237        I32ToF32(&'a Expr),
238    }
239    let wrap: Wrap = match (lt.is_vec(), rt.is_vec()) {
240        // vec ⟦cmp⟧ scalar — broadcast the scalar to vec
241        (true, false) if matches!(rt, WgslType::F32 | WgslType::I32) => Wrap::ToVec(lt, &b.rhs),
242        (false, true) if matches!(lt, WgslType::F32 | WgslType::I32) => Wrap::ToVec(rt, &b.lhs),
243        // i32 ⟦cmp⟧ f32 — promote the i32 side to f32. `for (var n: i32 = …;
244        // n <= anz; …)` where `anz: f32` is the dominant idiom (`for (int n =
245        // …; n <= floatvar; …)` in HLSL).
246        (false, false) => match (lt, rt) {
247            (WgslType::I32, WgslType::F32) => Wrap::I32ToF32(&b.lhs),
248            (WgslType::F32, WgslType::I32) => Wrap::I32ToF32(&b.rhs),
249            _ => return,
250        },
251        _ => return,
252    };
253    match wrap {
254        Wrap::ToVec(vec_ty, scalar_side) => {
255            // Skip if the scalar side is already a constructor of matching
256            // width — the wrap would be a no-op and accumulate on idempotent
257            // re-runs.
258            if let Expr::Call(c) = scalar_side
259                && let Some(t) = constructor_return(&c.callee)
260                && t == vec_ty
261            {
262                return;
263            }
264            let span = scalar_side.span();
265            // Emit the HLSL-shape constructor (`float3(x)`) rather than the
266            // WGSL shape (`vec3<f32>(x)`). The downstream `replace_types`
267            // pass rewrites it to the WGSL form, and avoiding `<f32>` in
268            // our inserted text means the later `looks_like_wgsl_generic`
269            // guard in `rewrite_bool_to_float` doesn't get tripped — it
270            // interprets `<f32>` as a sign of HLSL parser misreading
271            // `vec3<f32>(…)` as a chained comparison.
272            let ctor = match vec_ty {
273                WgslType::Vec2F => "float2",
274                WgslType::Vec3F => "float3",
275                WgslType::Vec4F => "float4",
276                _ => return,
277            };
278            ctx.edits.push(TextEdit {
279                start: span.start,
280                end: span.start,
281                replacement: format!("{ctor}("),
282            });
283            ctx.edits.push(TextEdit {
284                start: span.end,
285                end: span.end,
286                replacement: ")".to_string(),
287            });
288        }
289        Wrap::I32ToF32(int_side) => {
290            // Skip if already a `float(…)` cast or `f32(…)` cast.
291            if let Expr::Call(c) = int_side
292                && matches!(c.callee.as_str(), "float" | "f32")
293            {
294                return;
295            }
296            let span = int_side.span();
297            ctx.edits.push(TextEdit {
298                start: span.start,
299                end: span.start,
300                replacement: "float(".to_string(),
301            });
302            ctx.edits.push(TextEdit {
303                start: span.end,
304                end: span.end,
305                replacement: ")".to_string(),
306            });
307        }
308    }
309}