onedrop_hlsl/rewrite/vec_cmp.rs
1//! Pass: broadcast scalar operand of vec/scalar comparisons.
2
3use super::bool_arith::decl_target_float_kind;
4use super::*;
5
6// ---------------------------------------------------------------------------
7// Pass: broadcast scalar operand of vec/scalar comparisons
8// ---------------------------------------------------------------------------
9
10/// HLSL auto-broadcasts the scalar operand of a comparison
11/// (`vec3 >= 0.1` → element-wise `vec3 >= float3(0.1)`); WGSL rejects
12/// `vec3<f32> >= f32` with `InvalidBinaryOperandTypes` (a dominant
13/// validate-stage cluster, op=Ge/Le/Gt/Lt/Eq/Ne with
14/// `lhs_type: Vector`). Wraps the scalar operand in a `vecN<f32>(…)`
15/// constructor so the comparison evaluates as a vector-bool.
16///
17/// Stays conservative: only fires when one side is a *known* vector
18/// and the other is a *known* scalar f32. `(a < b) <= c` chained
19/// comparisons (where one side is already bool) are left alone — they
20/// belong to the eval long-tail, not this cluster.
21pub(crate) fn rewrite_vec_scalar_compare(src: &str) -> String {
22 let Ok(tu) = parse_hlsl(src) else {
23 return src.to_string();
24 };
25 let mut ctx = WalkCtx::new(src);
26 // Seed globals with the downstream-widened type (int → f32) rather
27 // than the literal HLSL type, mirroring `walk_stmt_for_vsc`'s
28 // body-decl path. `static const int anz = 35;` parses as a global, and
29 // after `replace_types` widens to `f32`; without this, the body's
30 // `n <= anz` lookup would see anz as I32 and the cast wrap wouldn't
31 // fire.
32 for item in &tu.items {
33 if let Item::GlobalVar(g) = item {
34 let ty = if g.array_len.is_some() {
35 WgslType::Unknown
36 } else {
37 decl_target_float_kind(&g.ty)
38 };
39 ctx.declare(&g.name, ty);
40 }
41 }
42 if let Some(body) = &tu.shader_body {
43 walk_block_for_vsc(body, &mut ctx);
44 }
45 for item in &tu.items {
46 if let Item::Function(f) = item {
47 ctx.scope_push();
48 for p in &f.params {
49 ctx.declare(&p.name, decl_target_float_kind(&p.ty));
50 }
51 walk_block_for_vsc(&f.body, &mut ctx);
52 ctx.scope_pop();
53 }
54 }
55 apply_edits(src, &mut ctx.edits)
56}
57
58fn walk_block_for_vsc(b: &Block, ctx: &mut WalkCtx) {
59 ctx.scope_push();
60 for s in &b.stmts {
61 walk_stmt_for_vsc(s, ctx);
62 }
63 ctx.scope_pop();
64}
65
66fn walk_stmt_for_vsc(s: &Stmt, ctx: &mut WalkCtx) {
67 match s {
68 Stmt::LocalDecl(d) => {
69 // Most `int x = …;` decls become `var x: f32 = …;` after
70 // `replace_types` widens int → f32, so the broadcast decision
71 // must reflect the *downstream* type. The exception is
72 // `for (int n = …; …)` loop counters, which
73 // `collect_for_int_edits` keeps as `i32`. The `For` arm below
74 // handles those specially.
75 ctx.declare(&d.name, decl_target_float_kind(&d.ty));
76 if let Some(init) = &d.init {
77 walk_expr_for_vsc(init, ctx);
78 }
79 }
80 Stmt::Assign(a) => {
81 walk_expr_for_vsc(&a.target, ctx);
82 walk_expr_for_vsc(&a.value, ctx);
83 }
84 Stmt::Expr(e) => walk_expr_for_vsc(e, ctx),
85 Stmt::If(i) => {
86 walk_expr_for_vsc(&i.cond, ctx);
87 walk_stmt_for_vsc(&i.then_branch, ctx);
88 if let Some(e) = &i.else_branch {
89 walk_stmt_for_vsc(e, ctx);
90 }
91 }
92 Stmt::While(w) => {
93 walk_expr_for_vsc(&w.cond, ctx);
94 walk_stmt_for_vsc(&w.body, ctx);
95 }
96 Stmt::For(f) => {
97 ctx.scope_push();
98 if let Some(init) = &f.init {
99 // For-loop counters keep their declared HLSL type (i32 for
100 // `int`, f32 for `float`), unlike the body-decl `int → f32`
101 // widening. Declare with `type_from_typeref` so a `for (int n
102 // = 1; n <= anz; …)` comparison sees lhs=I32 against
103 // rhs=F32 and the cast wrap fires.
104 if let Stmt::LocalDecl(d) = init.as_ref() {
105 ctx.declare(&d.name, type_from_typeref(&d.ty));
106 if let Some(e) = &d.init {
107 walk_expr_for_vsc(e, ctx);
108 }
109 } else {
110 walk_stmt_for_vsc(init, ctx);
111 }
112 }
113 if let Some(c) = &f.cond {
114 walk_expr_for_vsc(c, ctx);
115 }
116 if let Some(st) = &f.step {
117 walk_expr_for_vsc(st, ctx);
118 }
119 walk_stmt_for_vsc(&f.body, ctx);
120 ctx.scope_pop();
121 }
122 Stmt::Return(Some(e)) => walk_expr_for_vsc(e, ctx),
123 Stmt::Block(b) => walk_block_for_vsc(b, ctx),
124 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
125 }
126}
127
128fn walk_expr_for_vsc(e: &Expr, ctx: &mut WalkCtx) {
129 match e {
130 Expr::Binary(b) => {
131 walk_expr_for_vsc(&b.lhs, ctx);
132 walk_expr_for_vsc(&b.rhs, ctx);
133 if is_compare_op(b.op) {
134 try_broadcast_scalar_compare(b, ctx);
135 } else if is_arith_op(b.op) {
136 try_promote_i32_arith(b, ctx);
137 }
138 }
139 Expr::Unary(u) => walk_expr_for_vsc(&u.operand, ctx),
140 Expr::Ternary(t) => {
141 walk_expr_for_vsc(&t.cond, ctx);
142 walk_expr_for_vsc(&t.then_expr, ctx);
143 walk_expr_for_vsc(&t.else_expr, ctx);
144 }
145 Expr::Call(c) => {
146 for a in &c.args {
147 walk_expr_for_vsc(a, ctx);
148 }
149 }
150 Expr::Member(m) => walk_expr_for_vsc(&m.base, ctx),
151 Expr::Swizzle(s) => walk_expr_for_vsc(&s.base, ctx),
152 Expr::Index(i) => {
153 walk_expr_for_vsc(&i.base, ctx);
154 walk_expr_for_vsc(&i.index, ctx);
155 }
156 Expr::InitList(l) => {
157 for e in &l.elems {
158 walk_expr_for_vsc(e, ctx);
159 }
160 }
161 Expr::Assign(a) => {
162 walk_expr_for_vsc(&a.target, ctx);
163 walk_expr_for_vsc(&a.value, ctx);
164 }
165 Expr::Lit(_) | Expr::Ident(_, _) => {}
166 }
167}
168
169fn is_compare_op(op: BinaryOp) -> bool {
170 matches!(
171 op,
172 BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge
173 )
174}
175
176fn is_arith_op(op: BinaryOp) -> bool {
177 matches!(
178 op,
179 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Rem
180 )
181}
182
183/// Promote an `i32` operand of an `f32 ⟦arith⟧ i32` binop to `f32` by
184/// wrapping it in `float(…)`. WGSL refuses mixed `f32 * i32`
185/// arithmetic; corpus shape (`martin - pixies party 2 random
186/// mosscity.milk`):
187///
188/// ```text
189/// for (int n=1; n<=anz; n++) {
190/// z = 1 - fract(1.0*n/anz - 1.0*fract(-t_rel/anz));
191/// }
192/// ```
193///
194/// After `for_init`, `n: i32`. The condition's `n <= anz` already gets
195/// a `f32(n)` cast from [`try_broadcast_scalar_compare`]; the body's
196/// `1.0*n` needs the same treatment.
197fn try_promote_i32_arith(b: &crate::ast::BinaryExpr, ctx: &mut WalkCtx) {
198 let lt = infer_type_static(&b.lhs, ctx);
199 let rt = infer_type_static(&b.rhs, ctx);
200 let int_side: &Expr = match (lt, rt) {
201 (WgslType::I32, WgslType::F32) => &b.lhs,
202 (WgslType::F32, WgslType::I32) => &b.rhs,
203 // Both i32: leave alone (pure-int arithmetic is valid).
204 // Vector mixes are handled by the comparison path / binop_vec.
205 _ => return,
206 };
207 // Skip if already a `float(…)` / `f32(…)` cast — keeps the pass
208 // idempotent.
209 if let Expr::Call(c) = int_side
210 && matches!(c.callee.as_str(), "float" | "f32")
211 {
212 return;
213 }
214 let span = int_side.span();
215 ctx.edits.push(TextEdit {
216 start: span.start,
217 end: span.start,
218 replacement: "float(".to_string(),
219 });
220 ctx.edits.push(TextEdit {
221 start: span.end,
222 end: span.end,
223 replacement: ")".to_string(),
224 });
225}
226
227/// Inspect a comparison binop's operands; if one is a known vector and
228/// the other is a known scalar f32, emit a constructor wrap around the
229/// scalar so the comparison evaluates as a vector-bool. Zero-length
230/// insertions only, so the wrap survives even when other passes have
231/// also edited the operands.
232fn try_broadcast_scalar_compare(b: &crate::ast::BinaryExpr, ctx: &mut WalkCtx) {
233 let lt = infer_type_static(&b.lhs, ctx);
234 let rt = infer_type_static(&b.rhs, ctx);
235 enum Wrap<'a> {
236 ToVec(WgslType, &'a Expr),
237 I32ToF32(&'a Expr),
238 }
239 let wrap: Wrap = match (lt.is_vec(), rt.is_vec()) {
240 // vec ⟦cmp⟧ scalar — broadcast the scalar to vec
241 (true, false) if matches!(rt, WgslType::F32 | WgslType::I32) => Wrap::ToVec(lt, &b.rhs),
242 (false, true) if matches!(lt, WgslType::F32 | WgslType::I32) => Wrap::ToVec(rt, &b.lhs),
243 // i32 ⟦cmp⟧ f32 — promote the i32 side to f32. `for (var n: i32 = …;
244 // n <= anz; …)` where `anz: f32` is the dominant idiom (`for (int n =
245 // …; n <= floatvar; …)` in HLSL).
246 (false, false) => match (lt, rt) {
247 (WgslType::I32, WgslType::F32) => Wrap::I32ToF32(&b.lhs),
248 (WgslType::F32, WgslType::I32) => Wrap::I32ToF32(&b.rhs),
249 _ => return,
250 },
251 _ => return,
252 };
253 match wrap {
254 Wrap::ToVec(vec_ty, scalar_side) => {
255 // Skip if the scalar side is already a constructor of matching
256 // width — the wrap would be a no-op and accumulate on idempotent
257 // re-runs.
258 if let Expr::Call(c) = scalar_side
259 && let Some(t) = constructor_return(&c.callee)
260 && t == vec_ty
261 {
262 return;
263 }
264 let span = scalar_side.span();
265 // Emit the HLSL-shape constructor (`float3(x)`) rather than the
266 // WGSL shape (`vec3<f32>(x)`). The downstream `replace_types`
267 // pass rewrites it to the WGSL form, and avoiding `<f32>` in
268 // our inserted text means the later `looks_like_wgsl_generic`
269 // guard in `rewrite_bool_to_float` doesn't get tripped — it
270 // interprets `<f32>` as a sign of HLSL parser misreading
271 // `vec3<f32>(…)` as a chained comparison.
272 let ctor = match vec_ty {
273 WgslType::Vec2F => "float2",
274 WgslType::Vec3F => "float3",
275 WgslType::Vec4F => "float4",
276 _ => return,
277 };
278 ctx.edits.push(TextEdit {
279 start: span.start,
280 end: span.start,
281 replacement: format!("{ctor}("),
282 });
283 ctx.edits.push(TextEdit {
284 start: span.end,
285 end: span.end,
286 replacement: ")".to_string(),
287 });
288 }
289 Wrap::I32ToF32(int_side) => {
290 // Skip if already a `float(…)` cast or `f32(…)` cast.
291 if let Expr::Call(c) = int_side
292 && matches!(c.callee.as_str(), "float" | "f32")
293 {
294 return;
295 }
296 let span = int_side.span();
297 ctx.edits.push(TextEdit {
298 start: span.start,
299 end: span.start,
300 replacement: "float(".to_string(),
301 });
302 ctx.edits.push(TextEdit {
303 start: span.end,
304 end: span.end,
305 replacement: ")".to_string(),
306 });
307 }
308 }
309}