onedrop_hlsl/rewrite/bool_arith.rs
1//! Pass: bool RHS in f32 context → `select(0.0, 1.0, cond)`.
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass: bool RHS in f32 context → `select(0.0, 1.0, cond)`
7// ---------------------------------------------------------------------------
8
9/// HLSL coerces bool → float silently (`float mask = uv.x > 0.5;` yields
10/// 0.0/1.0). WGSL refuses with `the type of 'mask' is expected to be
11/// 'f32'; but got 'bool'` (or `cannot convert elements of vec3<bool> to
12/// f32` for vector comparisons). One of the largest comp failure
13/// clusters in the corpus.
14///
15/// Strategy: walk `LocalDecl`/`Assign` whose target type resolves to
16/// `f32` / `vec_n<f32>`; when the RHS is a comparison or short-circuit
17/// boolean, wrap it in `select(0.0, 1.0, <cond>)` (or `select(vec3<f32>(0.0),
18/// vec3<f32>(1.0), <cond>)` for vector). Stays conservative: only fires when
19/// both the target type and the RHS top-level operator are unambiguous.
20pub(crate) fn rewrite_bool_to_float(src: &str) -> String {
21 let Ok(tu) = parse_hlsl(src) else {
22 return src.to_string();
23 };
24 let mut ctx = WalkCtx::new(src);
25 ctx.seed_globals(&tu);
26 if let Some(body) = &tu.shader_body {
27 walk_block_for_bool(body, &mut ctx);
28 }
29 for item in &tu.items {
30 if let Item::Function(f) = item {
31 ctx.scope_push();
32 for p in &f.params {
33 ctx.declare(&p.name, type_from_typeref(&p.ty));
34 }
35 walk_block_for_bool(&f.body, &mut ctx);
36 ctx.scope_pop();
37 }
38 }
39 apply_edits(src, &mut ctx.edits)
40}
41
42fn walk_block_for_bool(b: &Block, ctx: &mut WalkCtx) {
43 ctx.scope_push();
44 for s in &b.stmts {
45 walk_stmt_for_bool(s, ctx);
46 }
47 ctx.scope_pop();
48}
49
50fn walk_stmt_for_bool(s: &Stmt, ctx: &mut WalkCtx) {
51 match s {
52 Stmt::LocalDecl(d) => {
53 // The LHS type used for the bool-coercion decision must
54 // reflect what the *downstream* WGSL var ends up as, not the
55 // literal HLSL token. `int mask = (c.y > 0);` is a dominant
56 // failure; `replace_types` later rewrites `int` → `f32`, but
57 // `type_from_typeref` returns `I32`, so the bool RHS wrap was
58 // skipped. Pretend `int` is already `f32` (same widening rule
59 // that `replace_types` applies to `int`/`half*`/`double*`).
60 let target_ty = decl_target_float_kind(&d.ty);
61 ctx.declare(&d.name, target_ty);
62 if let Some(init) = &d.init {
63 try_wrap_bool_rhs(init, target_ty, ctx);
64 walk_expr_for_bool_arith(init, ctx);
65 }
66 }
67 Stmt::Assign(a) => {
68 // For Assign, infer the target type from the LHS expression.
69 let target_ty = infer_type_static(&a.target, ctx);
70 try_wrap_bool_rhs(&a.value, target_ty, ctx);
71 walk_expr_for_bool(&a.target, ctx);
72 walk_expr_for_bool_arith(&a.value, ctx);
73 }
74 Stmt::Expr(e) => walk_expr_for_bool_arith(e, ctx),
75 Stmt::If(i) => {
76 walk_expr_for_bool(&i.cond, ctx);
77 walk_stmt_for_bool(&i.then_branch, ctx);
78 if let Some(e) = &i.else_branch {
79 walk_stmt_for_bool(e, ctx);
80 }
81 }
82 Stmt::While(w) => {
83 walk_expr_for_bool(&w.cond, ctx);
84 walk_stmt_for_bool(&w.body, ctx);
85 }
86 Stmt::For(f) => {
87 ctx.scope_push();
88 if let Some(init) = &f.init {
89 walk_stmt_for_bool(init, ctx);
90 }
91 if let Some(c) = &f.cond {
92 walk_expr_for_bool(c, ctx);
93 }
94 if let Some(st) = &f.step {
95 walk_expr_for_bool(st, ctx);
96 }
97 walk_stmt_for_bool(&f.body, ctx);
98 ctx.scope_pop();
99 }
100 // `return` expressions are arithmetic context — they feed into
101 // the caller's bool-or-float expectations. The corpus shape
102 // `return !mask*domain + mask*refrac_uv` (Flexi's
103 // `uv_lens_half_sphere`) needs the bool-typed `mask` operands
104 // wrapped just like a `LocalDecl` RHS would.
105 Stmt::Return(Some(e)) => walk_expr_for_bool_arith(e, ctx),
106 Stmt::Block(b) => walk_block_for_bool(b, ctx),
107 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
108 }
109}
110
111/// Top-level walk for nested sub-statements; the bool-wrap heuristic only
112/// fires on the RHS of an assignment / decl, so this just recurses without
113/// emitting edits of its own.
114#[allow(clippy::only_used_in_recursion)]
115fn walk_expr_for_bool(e: &Expr, ctx: &mut WalkCtx) {
116 match e {
117 Expr::Binary(b) => {
118 walk_expr_for_bool(&b.lhs, ctx);
119 walk_expr_for_bool(&b.rhs, ctx);
120 }
121 Expr::Unary(u) => walk_expr_for_bool(&u.operand, ctx),
122 Expr::Ternary(t) => {
123 walk_expr_for_bool(&t.cond, ctx);
124 walk_expr_for_bool(&t.then_expr, ctx);
125 walk_expr_for_bool(&t.else_expr, ctx);
126 }
127 Expr::Call(c) => {
128 for a in &c.args {
129 walk_expr_for_bool(a, ctx);
130 }
131 }
132 Expr::Member(m) => walk_expr_for_bool(&m.base, ctx),
133 Expr::Swizzle(s) => walk_expr_for_bool(&s.base, ctx),
134 Expr::Index(i) => {
135 walk_expr_for_bool(&i.base, ctx);
136 walk_expr_for_bool(&i.index, ctx);
137 }
138 Expr::InitList(l) => {
139 for e in &l.elems {
140 walk_expr_for_bool(e, ctx);
141 }
142 }
143 Expr::Assign(a) => {
144 walk_expr_for_bool(&a.target, ctx);
145 walk_expr_for_bool(&a.value, ctx);
146 }
147 Expr::Lit(_) | Expr::Ident(_, _) => {}
148 }
149}
150
151/// Walk an expression that lives in arithmetic context (the RHS of a
152/// decl/assign, an `*=` RHS, a `*`/`+`/`-`/`/` sub-tree). For every
153/// `Binary` whose operator expects numeric operands, wrap any
154/// bool-producing child (comparison or short-circuit) in `select(0.0,
155/// 1.0, …)`. Targets a dominant validate-stage failure
156/// (`InvalidBinaryOperandTypes`), exemplified by the MD2 gating idiom
157/// `value * (rs.z > 0) * (rs.z < hlim)`.
158///
159/// We don't try to peek through `Member`/`Swizzle`/`Index`/Call args
160/// here — those go through `try_wrap_bool_rhs` and the call-site coerce
161/// pass at function granularity. Conservative wrapping keeps the edits
162/// strictly zero-length insertions so existing inner edits (from
163/// `rewrite_binary_vec_mismatches`, etc.) don't fight us.
164fn walk_expr_for_bool_arith(e: &Expr, ctx: &mut WalkCtx) {
165 if let Expr::Binary(b) = e
166 && is_arith_op(b.op)
167 {
168 try_wrap_bool_arith_operand(&b.lhs, ctx);
169 try_wrap_bool_arith_operand(&b.rhs, ctx);
170 }
171 walk_expr_for_bool(e, ctx);
172 // Recurse into common arithmetic structures so multi-level
173 // `a + (b > c) * d` chains catch every bool operand.
174 match e {
175 Expr::Binary(b) => {
176 walk_expr_for_bool_arith(&b.lhs, ctx);
177 walk_expr_for_bool_arith(&b.rhs, ctx);
178 }
179 Expr::Unary(u) => {
180 // `-<bool>` and `+<bool>` are invalid in WGSL (naga's
181 // `InvalidUnaryOperandType(Negate, …)` — the dominant fs_main
182 // validation sub-bucket on the post-bb4754a 2000-sample,
183 // ≈18 presets). Treat the operand as arithmetic context so
184 // a comparison inside (`-(lum(ret)<sin(time*90))*…`) gets
185 // wrapped in `select(0.0, 1.0, …)` *before* the unary
186 // operator applies — yielding the valid `-(select(…))*…`.
187 // `!<bool>` stays handled by the parent's
188 // `try_wrap_bool_arith_operand` since `is_boolean_producing`
189 // accepts `Unary(Not, bool)` as a whole.
190 if matches!(u.op, UnaryOp::Neg | UnaryOp::Pos) {
191 try_wrap_bool_arith_operand(&u.operand, ctx);
192 }
193 walk_expr_for_bool_arith(&u.operand, ctx);
194 }
195 Expr::Ternary(t) => {
196 walk_expr_for_bool_arith(&t.then_expr, ctx);
197 walk_expr_for_bool_arith(&t.else_expr, ctx);
198 }
199 Expr::Call(c) => {
200 for a in &c.args {
201 walk_expr_for_bool_arith(a, ctx);
202 }
203 }
204 _ => {}
205 }
206}
207
208fn is_arith_op(op: BinaryOp) -> bool {
209 matches!(
210 op,
211 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Rem
212 )
213}
214
215/// Emit a `select(0.0, 1.0, (<expr>))` zero-length-insertion wrap around
216/// `e` if its top-level operator is a boolean producer. Mirrors
217/// [`try_wrap_bool_rhs`] but is called from arithmetic context, and also
218/// emits the vector-shape form
219/// `select(vec3<f32>(0.0), vec3<f32>(1.0), (vec_cmp))` when the bool's
220/// operands are vector-typed, since WGSL `select(scalar, scalar,
221/// vecN<bool>)` is invalid.
222fn try_wrap_bool_arith_operand(e: &Expr, ctx: &mut WalkCtx) {
223 if !is_boolean_producing(e, ctx) {
224 return;
225 }
226 let text = &ctx.src[e.span().start as usize..e.span().end as usize];
227 if looks_like_wgsl_generic(text) {
228 return;
229 }
230 if is_chained_comparison(e) {
231 return;
232 }
233 // Skip if already wrapped in `select(…)` — re-running the pass on its
234 // own output should be idempotent. We detect by checking the byte
235 // immediately before the start span: a preceding `(` of a `select(`
236 // call form is the marker.
237 let start = e.span().start as usize;
238 if start >= 8 {
239 let preceding = &ctx.src[start.saturating_sub(8)..start];
240 if preceding.contains("1.0, ") || preceding.contains("1.0,(") {
241 return;
242 }
243 }
244 let (zero, one) = bool_shape_select_literals(e, ctx);
245 let span = e.span();
246 ctx.edits.push(TextEdit {
247 start: span.start,
248 end: span.start,
249 replacement: format!("select({zero}, {one}, ("),
250 });
251 ctx.edits.push(TextEdit {
252 start: span.end,
253 end: span.end,
254 replacement: "))".to_string(),
255 });
256}
257
258/// Pick the `(zero, one)` literals for a `select(0.0/vecN, …)` wrap based
259/// on the *shape* of the bool-producing expression. A scalar comparison
260/// (e.g. `rs.z > 0`) keeps the scalar `0.0`/`1.0`; a vector comparison
261/// (e.g. `noise >= 0.1`) needs `vecN<f32>(0.0)`/`vecN<f32>(1.0)` so the
262/// resulting `select(...)` typechecks under WGSL's strict-shape rules.
263fn bool_shape_select_literals(e: &Expr, ctx: &mut WalkCtx) -> (&'static str, &'static str) {
264 let shape = infer_bool_shape(e, ctx);
265 match shape {
266 WgslType::Vec2F => ("vec2<f32>(0.0)", "vec2<f32>(1.0)"),
267 WgslType::Vec3F => ("vec3<f32>(0.0)", "vec3<f32>(1.0)"),
268 WgslType::Vec4F => ("vec4<f32>(0.0)", "vec4<f32>(1.0)"),
269 _ => ("0.0", "1.0"),
270 }
271}
272
273/// Walks the bool-producing subtree and returns the inferred *vector*
274/// shape of the underlying values being compared. Returns `F32` for
275/// scalar bool (so the caller emits scalar literals), `VecN` for vector
276/// bool, `Unknown` when nothing can be inferred (caller falls back to
277/// scalar — which matches the conservative pre-AA behaviour).
278fn infer_bool_shape(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
279 match e {
280 Expr::Binary(b) => match b.op {
281 BinaryOp::Eq
282 | BinaryOp::Ne
283 | BinaryOp::Lt
284 | BinaryOp::Le
285 | BinaryOp::Gt
286 | BinaryOp::Ge => widen_type(
287 infer_type_static(&b.lhs, ctx),
288 infer_type_static(&b.rhs, ctx),
289 ),
290 BinaryOp::And | BinaryOp::Or => {
291 widen_type(infer_bool_shape(&b.lhs, ctx), infer_bool_shape(&b.rhs, ctx))
292 }
293 _ => WgslType::Unknown,
294 },
295 Expr::Unary(u) if matches!(u.op, UnaryOp::Not) => infer_bool_shape(&u.operand, ctx),
296 Expr::Swizzle(s) => match s.components.len() {
297 2 => WgslType::Vec2F,
298 3 => WgslType::Vec3F,
299 4 => WgslType::Vec4F,
300 _ => WgslType::F32,
301 },
302 _ => WgslType::Unknown,
303 }
304}
305
306/// Maps an HLSL declaration's textual type to the WGSL type the
307/// downstream `replace_types` pass will produce — for the purpose of
308/// deciding whether the RHS should be bool-coerced. Treats `int`, `half`,
309/// `double` and their vector siblings as float-kind, mirroring the
310/// substitution rules in [`crate::replace_types`].
311pub(super) fn decl_target_float_kind(t: &TypeRef) -> WgslType {
312 match t.name.as_str() {
313 "int" | "half" | "double" | "half1" => WgslType::F32,
314 "half2" | "double2" => WgslType::Vec2F,
315 "half3" | "double3" => WgslType::Vec3F,
316 "half4" | "double4" => WgslType::Vec4F,
317 _ => type_from_typeref(t),
318 }
319}
320
321/// If `rhs` is a comparison / boolean operator and `target_ty` is a
322/// numeric float type, emit two zero-length edits to wrap the RHS in
323/// `select(<zero>, <one>, <rhs>)`. Zero-length insertions don't collide
324/// with inner edits emitted by other passes.
325fn try_wrap_bool_rhs(rhs: &Expr, target_ty: WgslType, ctx: &mut WalkCtx) {
326 if !is_float_kind(target_ty) {
327 return;
328 }
329 if !is_boolean_producing(rhs, ctx) {
330 return;
331 }
332 // Conservative guard against false positives where the HLSL parser
333 // misinterprets a WGSL-style generic (`vec3<f32>(…)`, `array<f32, 4>`)
334 // as a chained comparison `vec3 < f32 > (…)`. The `<…>` brackets show
335 // up textually; if we see one in the RHS slice, this is not a real
336 // boolean expression — bail.
337 let rhs_text = &ctx.src[rhs.span().start as usize..rhs.span().end as usize];
338 if looks_like_wgsl_generic(rhs_text) {
339 return;
340 }
341 // Skip chained comparisons (`a < b > c`) — almost always parser noise
342 // from generic-type lookalikes, never a real preset idiom.
343 if is_chained_comparison(rhs) {
344 return;
345 }
346 let (zero, one) = match target_ty {
347 WgslType::F32 => ("0.0", "1.0"),
348 WgslType::Vec2F => ("vec2<f32>(0.0)", "vec2<f32>(1.0)"),
349 WgslType::Vec3F => ("vec3<f32>(0.0)", "vec3<f32>(1.0)"),
350 WgslType::Vec4F => ("vec4<f32>(0.0)", "vec4<f32>(1.0)"),
351 _ => return,
352 };
353
354 // WGSL `select` rejects `select(scalar, scalar, vec<bool>)` — the
355 // condition's shape must match the value shape, or the values must
356 // be scalar AND the condition scalar. Detect the "scalar target /
357 // vector cmp" mismatch (HLSL silently extracts component 0 when
358 // assigning `float x = vec_cmp;`) and wrap the cmp in `(cmp).x` to
359 // recover scalar bool.
360 let cmp_shape = infer_bool_shape(rhs, ctx);
361 let needs_scalar_extract = matches!(target_ty, WgslType::F32)
362 && matches!(
363 cmp_shape,
364 WgslType::Vec2F | WgslType::Vec3F | WgslType::Vec4F
365 );
366
367 let span = rhs.span();
368 if needs_scalar_extract {
369 // Result form: `select(0.0, 1.0, ((cmp)).x)`.
370 ctx.edits.push(TextEdit {
371 start: span.start,
372 end: span.start,
373 replacement: format!("select({zero}, {one}, (("),
374 });
375 ctx.edits.push(TextEdit {
376 start: span.end,
377 end: span.end,
378 replacement: ")).x)".to_string(),
379 });
380 } else {
381 ctx.edits.push(TextEdit {
382 start: span.start,
383 end: span.start,
384 replacement: format!("select({zero}, {one}, ("),
385 });
386 ctx.edits.push(TextEdit {
387 start: span.end,
388 end: span.end,
389 replacement: "))".to_string(),
390 });
391 }
392}
393
/// True when `text` contains a WGSL generic argument list over a primitive
/// type: `<f32>`, `<i32>`, `<u32>` or `<bool>`.
///
/// Such text is WGSL-shaped code that the HLSL parser may have misread as
/// comparisons.
fn looks_like_wgsl_generic(text: &str) -> bool {
    const GENERIC_ARGS: [&str; 4] = ["<f32>", "<i32>", "<u32>", "<bool>"];
    GENERIC_ARGS.iter().any(|&needle| text.contains(needle))
}
400
401fn is_chained_comparison(e: &Expr) -> bool {
402 let Expr::Binary(b) = e else { return false };
403 let is_cmp = |op: BinaryOp| {
404 matches!(
405 op,
406 BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge
407 )
408 };
409 if !is_cmp(b.op) {
410 return false;
411 }
412 let inner = |x: &Expr| matches!(x, Expr::Binary(c) if is_cmp(c.op));
413 inner(&b.lhs) || inner(&b.rhs)
414}
415
416fn is_float_kind(t: WgslType) -> bool {
417 matches!(
418 t,
419 WgslType::F32 | WgslType::Vec2F | WgslType::Vec3F | WgslType::Vec4F
420 )
421}
422
423/// True when the top-level operator of `e` is one that produces a bool
424/// value in WGSL (comparison, short-circuit boolean, or a swizzle over
425/// one). We follow `Swizzle` so that `(vec >= 0).xyz` reaches this
426/// function during a `vec3` assignment and the bool-vec3 result still
427/// gets the `select(...)` wrap.
428///
429/// The `ctx` argument lets us also catch the HLSL idiom of using a
430/// bool-typed *variable* in arithmetic context (corpus shape:
431/// `bool mask = …; return !mask*domain + mask*refrac_uv;`). HLSL
432/// silently coerces those bool operands to 0.0 / 1.0; WGSL refuses
433/// the multiplication outright, so the wrap has to apply at the
434/// identifier level too.
435fn is_boolean_producing(e: &Expr, ctx: &mut WalkCtx) -> bool {
436 match e {
437 Expr::Binary(b) => matches!(
438 b.op,
439 BinaryOp::Eq
440 | BinaryOp::Ne
441 | BinaryOp::Lt
442 | BinaryOp::Le
443 | BinaryOp::Gt
444 | BinaryOp::Ge
445 | BinaryOp::And
446 | BinaryOp::Or
447 ),
448 Expr::Unary(u) => matches!(u.op, UnaryOp::Not) && is_boolean_producing(&u.operand, ctx),
449 Expr::Swizzle(s) => is_boolean_producing(&s.base, ctx),
450 Expr::Ident(name, _) => matches!(ctx.lookup(name), WgslType::Bool),
451 _ => false,
452 }
453}