onedrop_hlsl/rewrite/
bool_arith.rs

1//! Pass: bool RHS in f32 context → `select(0.0, 1.0, cond)`.
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass: bool RHS in f32 context → `select(0.0, 1.0, cond)`
7// ---------------------------------------------------------------------------
8
9/// HLSL coerces bool → float silently (`float mask = uv.x > 0.5;` yields
10/// 0.0/1.0). WGSL refuses with `the type of 'mask' is expected to be
11/// 'f32'; but got 'bool'` (or `cannot convert elements of vec3<bool> to
12/// f32` for vector comparisons). One of the largest comp failure
13/// clusters in the corpus.
14///
15/// Strategy: walk `LocalDecl`/`Assign` whose target type resolves to
16/// `f32` / `vec_n<f32>`; when the RHS is a comparison or short-circuit
17/// boolean, wrap it in `select(0.0, 1.0, <cond>)` (or `select(vec3<f32>(0.0),
18/// vec3<f32>(1.0), <cond>)` for vector). Stays conservative: only fires when
19/// both the target type and the RHS top-level operator are unambiguous.
20pub(crate) fn rewrite_bool_to_float(src: &str) -> String {
21    let Ok(tu) = parse_hlsl(src) else {
22        return src.to_string();
23    };
24    let mut ctx = WalkCtx::new(src);
25    ctx.seed_globals(&tu);
26    if let Some(body) = &tu.shader_body {
27        walk_block_for_bool(body, &mut ctx);
28    }
29    for item in &tu.items {
30        if let Item::Function(f) = item {
31            ctx.scope_push();
32            for p in &f.params {
33                ctx.declare(&p.name, type_from_typeref(&p.ty));
34            }
35            walk_block_for_bool(&f.body, &mut ctx);
36            ctx.scope_pop();
37        }
38    }
39    apply_edits(src, &mut ctx.edits)
40}
41
42fn walk_block_for_bool(b: &Block, ctx: &mut WalkCtx) {
43    ctx.scope_push();
44    for s in &b.stmts {
45        walk_stmt_for_bool(s, ctx);
46    }
47    ctx.scope_pop();
48}
49
50fn walk_stmt_for_bool(s: &Stmt, ctx: &mut WalkCtx) {
51    match s {
52        Stmt::LocalDecl(d) => {
53            // The LHS type used for the bool-coercion decision must
54            // reflect what the *downstream* WGSL var ends up as, not the
55            // literal HLSL token. `int mask = (c.y > 0);` is a dominant
56            // failure; `replace_types` later rewrites `int` → `f32`, but
57            // `type_from_typeref` returns `I32`, so the bool RHS wrap was
58            // skipped. Pretend `int` is already `f32` (same widening rule
59            // that `replace_types` applies to `int`/`half*`/`double*`).
60            let target_ty = decl_target_float_kind(&d.ty);
61            ctx.declare(&d.name, target_ty);
62            if let Some(init) = &d.init {
63                try_wrap_bool_rhs(init, target_ty, ctx);
64                walk_expr_for_bool_arith(init, ctx);
65            }
66        }
67        Stmt::Assign(a) => {
68            // For Assign, infer the target type from the LHS expression.
69            let target_ty = infer_type_static(&a.target, ctx);
70            try_wrap_bool_rhs(&a.value, target_ty, ctx);
71            walk_expr_for_bool(&a.target, ctx);
72            walk_expr_for_bool_arith(&a.value, ctx);
73        }
74        Stmt::Expr(e) => walk_expr_for_bool_arith(e, ctx),
75        Stmt::If(i) => {
76            walk_expr_for_bool(&i.cond, ctx);
77            walk_stmt_for_bool(&i.then_branch, ctx);
78            if let Some(e) = &i.else_branch {
79                walk_stmt_for_bool(e, ctx);
80            }
81        }
82        Stmt::While(w) => {
83            walk_expr_for_bool(&w.cond, ctx);
84            walk_stmt_for_bool(&w.body, ctx);
85        }
86        Stmt::For(f) => {
87            ctx.scope_push();
88            if let Some(init) = &f.init {
89                walk_stmt_for_bool(init, ctx);
90            }
91            if let Some(c) = &f.cond {
92                walk_expr_for_bool(c, ctx);
93            }
94            if let Some(st) = &f.step {
95                walk_expr_for_bool(st, ctx);
96            }
97            walk_stmt_for_bool(&f.body, ctx);
98            ctx.scope_pop();
99        }
100        // `return` expressions are arithmetic context — they feed into
101        // the caller's bool-or-float expectations. The corpus shape
102        // `return !mask*domain + mask*refrac_uv` (Flexi's
103        // `uv_lens_half_sphere`) needs the bool-typed `mask` operands
104        // wrapped just like a `LocalDecl` RHS would.
105        Stmt::Return(Some(e)) => walk_expr_for_bool_arith(e, ctx),
106        Stmt::Block(b) => walk_block_for_bool(b, ctx),
107        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
108    }
109}
110
111/// Top-level walk for nested sub-statements; the bool-wrap heuristic only
112/// fires on the RHS of an assignment / decl, so this just recurses without
113/// emitting edits of its own.
114#[allow(clippy::only_used_in_recursion)]
115fn walk_expr_for_bool(e: &Expr, ctx: &mut WalkCtx) {
116    match e {
117        Expr::Binary(b) => {
118            walk_expr_for_bool(&b.lhs, ctx);
119            walk_expr_for_bool(&b.rhs, ctx);
120        }
121        Expr::Unary(u) => walk_expr_for_bool(&u.operand, ctx),
122        Expr::Ternary(t) => {
123            walk_expr_for_bool(&t.cond, ctx);
124            walk_expr_for_bool(&t.then_expr, ctx);
125            walk_expr_for_bool(&t.else_expr, ctx);
126        }
127        Expr::Call(c) => {
128            for a in &c.args {
129                walk_expr_for_bool(a, ctx);
130            }
131        }
132        Expr::Member(m) => walk_expr_for_bool(&m.base, ctx),
133        Expr::Swizzle(s) => walk_expr_for_bool(&s.base, ctx),
134        Expr::Index(i) => {
135            walk_expr_for_bool(&i.base, ctx);
136            walk_expr_for_bool(&i.index, ctx);
137        }
138        Expr::InitList(l) => {
139            for e in &l.elems {
140                walk_expr_for_bool(e, ctx);
141            }
142        }
143        Expr::Assign(a) => {
144            walk_expr_for_bool(&a.target, ctx);
145            walk_expr_for_bool(&a.value, ctx);
146        }
147        Expr::Lit(_) | Expr::Ident(_, _) => {}
148    }
149}
150
151/// Walk an expression that lives in arithmetic context (the RHS of a
152/// decl/assign, an `*=` RHS, a `*`/`+`/`-`/`/` sub-tree). For every
153/// `Binary` whose operator expects numeric operands, wrap any
154/// bool-producing child (comparison or short-circuit) in `select(0.0,
155/// 1.0, …)`. Targets a dominant validate-stage failure
156/// (`InvalidBinaryOperandTypes`), exemplified by the MD2 gating idiom
157/// `value * (rs.z > 0) * (rs.z < hlim)`.
158///
159/// We don't try to peek through `Member`/`Swizzle`/`Index`/Call args
160/// here — those go through `try_wrap_bool_rhs` and the call-site coerce
161/// pass at function granularity. Conservative wrapping keeps the edits
162/// strictly zero-length insertions so existing inner edits (from
163/// `rewrite_binary_vec_mismatches`, etc.) don't fight us.
164fn walk_expr_for_bool_arith(e: &Expr, ctx: &mut WalkCtx) {
165    if let Expr::Binary(b) = e
166        && is_arith_op(b.op)
167    {
168        try_wrap_bool_arith_operand(&b.lhs, ctx);
169        try_wrap_bool_arith_operand(&b.rhs, ctx);
170    }
171    walk_expr_for_bool(e, ctx);
172    // Recurse into common arithmetic structures so multi-level
173    // `a + (b > c) * d` chains catch every bool operand.
174    match e {
175        Expr::Binary(b) => {
176            walk_expr_for_bool_arith(&b.lhs, ctx);
177            walk_expr_for_bool_arith(&b.rhs, ctx);
178        }
179        Expr::Unary(u) => {
180            // `-<bool>` and `+<bool>` are invalid in WGSL (naga's
181            // `InvalidUnaryOperandType(Negate, …)` — the dominant fs_main
182            // validation sub-bucket on the post-bb4754a 2000-sample,
183            // ≈18 presets). Treat the operand as arithmetic context so
184            // a comparison inside (`-(lum(ret)<sin(time*90))*…`) gets
185            // wrapped in `select(0.0, 1.0, …)` *before* the unary
186            // operator applies — yielding the valid `-(select(…))*…`.
187            // `!<bool>` stays handled by the parent's
188            // `try_wrap_bool_arith_operand` since `is_boolean_producing`
189            // accepts `Unary(Not, bool)` as a whole.
190            if matches!(u.op, UnaryOp::Neg | UnaryOp::Pos) {
191                try_wrap_bool_arith_operand(&u.operand, ctx);
192            }
193            walk_expr_for_bool_arith(&u.operand, ctx);
194        }
195        Expr::Ternary(t) => {
196            walk_expr_for_bool_arith(&t.then_expr, ctx);
197            walk_expr_for_bool_arith(&t.else_expr, ctx);
198        }
199        Expr::Call(c) => {
200            for a in &c.args {
201                walk_expr_for_bool_arith(a, ctx);
202            }
203        }
204        _ => {}
205    }
206}
207
208fn is_arith_op(op: BinaryOp) -> bool {
209    matches!(
210        op,
211        BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Rem
212    )
213}
214
215/// Emit a `select(0.0, 1.0, (<expr>))` zero-length-insertion wrap around
216/// `e` if its top-level operator is a boolean producer. Mirrors
217/// [`try_wrap_bool_rhs`] but is called from arithmetic context, and also
218/// emits the vector-shape form
219/// `select(vec3<f32>(0.0), vec3<f32>(1.0), (vec_cmp))` when the bool's
220/// operands are vector-typed, since WGSL `select(scalar, scalar,
221/// vecN<bool>)` is invalid.
222fn try_wrap_bool_arith_operand(e: &Expr, ctx: &mut WalkCtx) {
223    if !is_boolean_producing(e, ctx) {
224        return;
225    }
226    let text = &ctx.src[e.span().start as usize..e.span().end as usize];
227    if looks_like_wgsl_generic(text) {
228        return;
229    }
230    if is_chained_comparison(e) {
231        return;
232    }
233    // Skip if already wrapped in `select(…)` — re-running the pass on its
234    // own output should be idempotent. We detect by checking the byte
235    // immediately before the start span: a preceding `(` of a `select(`
236    // call form is the marker.
237    let start = e.span().start as usize;
238    if start >= 8 {
239        let preceding = &ctx.src[start.saturating_sub(8)..start];
240        if preceding.contains("1.0, ") || preceding.contains("1.0,(") {
241            return;
242        }
243    }
244    let (zero, one) = bool_shape_select_literals(e, ctx);
245    let span = e.span();
246    ctx.edits.push(TextEdit {
247        start: span.start,
248        end: span.start,
249        replacement: format!("select({zero}, {one}, ("),
250    });
251    ctx.edits.push(TextEdit {
252        start: span.end,
253        end: span.end,
254        replacement: "))".to_string(),
255    });
256}
257
258/// Pick the `(zero, one)` literals for a `select(0.0/vecN, …)` wrap based
259/// on the *shape* of the bool-producing expression. A scalar comparison
260/// (e.g. `rs.z > 0`) keeps the scalar `0.0`/`1.0`; a vector comparison
261/// (e.g. `noise >= 0.1`) needs `vecN<f32>(0.0)`/`vecN<f32>(1.0)` so the
262/// resulting `select(...)` typechecks under WGSL's strict-shape rules.
263fn bool_shape_select_literals(e: &Expr, ctx: &mut WalkCtx) -> (&'static str, &'static str) {
264    let shape = infer_bool_shape(e, ctx);
265    match shape {
266        WgslType::Vec2F => ("vec2<f32>(0.0)", "vec2<f32>(1.0)"),
267        WgslType::Vec3F => ("vec3<f32>(0.0)", "vec3<f32>(1.0)"),
268        WgslType::Vec4F => ("vec4<f32>(0.0)", "vec4<f32>(1.0)"),
269        _ => ("0.0", "1.0"),
270    }
271}
272
273/// Walks the bool-producing subtree and returns the inferred *vector*
274/// shape of the underlying values being compared. Returns `F32` for
275/// scalar bool (so the caller emits scalar literals), `VecN` for vector
276/// bool, `Unknown` when nothing can be inferred (caller falls back to
277/// scalar — which matches the conservative pre-AA behaviour).
278fn infer_bool_shape(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
279    match e {
280        Expr::Binary(b) => match b.op {
281            BinaryOp::Eq
282            | BinaryOp::Ne
283            | BinaryOp::Lt
284            | BinaryOp::Le
285            | BinaryOp::Gt
286            | BinaryOp::Ge => widen_type(
287                infer_type_static(&b.lhs, ctx),
288                infer_type_static(&b.rhs, ctx),
289            ),
290            BinaryOp::And | BinaryOp::Or => {
291                widen_type(infer_bool_shape(&b.lhs, ctx), infer_bool_shape(&b.rhs, ctx))
292            }
293            _ => WgslType::Unknown,
294        },
295        Expr::Unary(u) if matches!(u.op, UnaryOp::Not) => infer_bool_shape(&u.operand, ctx),
296        Expr::Swizzle(s) => match s.components.len() {
297            2 => WgslType::Vec2F,
298            3 => WgslType::Vec3F,
299            4 => WgslType::Vec4F,
300            _ => WgslType::F32,
301        },
302        _ => WgslType::Unknown,
303    }
304}
305
306/// Maps an HLSL declaration's textual type to the WGSL type the
307/// downstream `replace_types` pass will produce — for the purpose of
308/// deciding whether the RHS should be bool-coerced. Treats `int`, `half`,
309/// `double` and their vector siblings as float-kind, mirroring the
310/// substitution rules in [`crate::replace_types`].
311pub(super) fn decl_target_float_kind(t: &TypeRef) -> WgslType {
312    match t.name.as_str() {
313        "int" | "half" | "double" | "half1" => WgslType::F32,
314        "half2" | "double2" => WgslType::Vec2F,
315        "half3" | "double3" => WgslType::Vec3F,
316        "half4" | "double4" => WgslType::Vec4F,
317        _ => type_from_typeref(t),
318    }
319}
320
321/// If `rhs` is a comparison / boolean operator and `target_ty` is a
322/// numeric float type, emit two zero-length edits to wrap the RHS in
323/// `select(<zero>, <one>, <rhs>)`. Zero-length insertions don't collide
324/// with inner edits emitted by other passes.
325fn try_wrap_bool_rhs(rhs: &Expr, target_ty: WgslType, ctx: &mut WalkCtx) {
326    if !is_float_kind(target_ty) {
327        return;
328    }
329    if !is_boolean_producing(rhs, ctx) {
330        return;
331    }
332    // Conservative guard against false positives where the HLSL parser
333    // misinterprets a WGSL-style generic (`vec3<f32>(…)`, `array<f32, 4>`)
334    // as a chained comparison `vec3 < f32 > (…)`. The `<…>` brackets show
335    // up textually; if we see one in the RHS slice, this is not a real
336    // boolean expression — bail.
337    let rhs_text = &ctx.src[rhs.span().start as usize..rhs.span().end as usize];
338    if looks_like_wgsl_generic(rhs_text) {
339        return;
340    }
341    // Skip chained comparisons (`a < b > c`) — almost always parser noise
342    // from generic-type lookalikes, never a real preset idiom.
343    if is_chained_comparison(rhs) {
344        return;
345    }
346    let (zero, one) = match target_ty {
347        WgslType::F32 => ("0.0", "1.0"),
348        WgslType::Vec2F => ("vec2<f32>(0.0)", "vec2<f32>(1.0)"),
349        WgslType::Vec3F => ("vec3<f32>(0.0)", "vec3<f32>(1.0)"),
350        WgslType::Vec4F => ("vec4<f32>(0.0)", "vec4<f32>(1.0)"),
351        _ => return,
352    };
353
354    // WGSL `select` rejects `select(scalar, scalar, vec<bool>)` — the
355    // condition's shape must match the value shape, or the values must
356    // be scalar AND the condition scalar. Detect the "scalar target /
357    // vector cmp" mismatch (HLSL silently extracts component 0 when
358    // assigning `float x = vec_cmp;`) and wrap the cmp in `(cmp).x` to
359    // recover scalar bool.
360    let cmp_shape = infer_bool_shape(rhs, ctx);
361    let needs_scalar_extract = matches!(target_ty, WgslType::F32)
362        && matches!(
363            cmp_shape,
364            WgslType::Vec2F | WgslType::Vec3F | WgslType::Vec4F
365        );
366
367    let span = rhs.span();
368    if needs_scalar_extract {
369        // Result form: `select(0.0, 1.0, ((cmp)).x)`.
370        ctx.edits.push(TextEdit {
371            start: span.start,
372            end: span.start,
373            replacement: format!("select({zero}, {one}, (("),
374        });
375        ctx.edits.push(TextEdit {
376            start: span.end,
377            end: span.end,
378            replacement: ")).x)".to_string(),
379        });
380    } else {
381        ctx.edits.push(TextEdit {
382            start: span.start,
383            end: span.start,
384            replacement: format!("select({zero}, {one}, ("),
385        });
386        ctx.edits.push(TextEdit {
387            start: span.end,
388            end: span.end,
389            replacement: "))".to_string(),
390        });
391    }
392}
393
/// True when `text` contains a WGSL generic argument list (`<f32>`, `<i32>`,
/// `<u32>` or `<bool>`). Such text was most likely misparsed as a comparison
/// and must not be bool-wrapped.
fn looks_like_wgsl_generic(text: &str) -> bool {
    const GENERIC_ARGS: [&str; 4] = ["<f32>", "<i32>", "<u32>", "<bool>"];
    GENERIC_ARGS.iter().any(|&needle| text.contains(needle))
}
400
401fn is_chained_comparison(e: &Expr) -> bool {
402    let Expr::Binary(b) = e else { return false };
403    let is_cmp = |op: BinaryOp| {
404        matches!(
405            op,
406            BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge
407        )
408    };
409    if !is_cmp(b.op) {
410        return false;
411    }
412    let inner = |x: &Expr| matches!(x, Expr::Binary(c) if is_cmp(c.op));
413    inner(&b.lhs) || inner(&b.rhs)
414}
415
416fn is_float_kind(t: WgslType) -> bool {
417    matches!(
418        t,
419        WgslType::F32 | WgslType::Vec2F | WgslType::Vec3F | WgslType::Vec4F
420    )
421}
422
423/// True when the top-level operator of `e` is one that produces a bool
424/// value in WGSL (comparison, short-circuit boolean, or a swizzle over
425/// one). We follow `Swizzle` so that `(vec >= 0).xyz` reaches this
426/// function during a `vec3` assignment and the bool-vec3 result still
427/// gets the `select(...)` wrap.
428///
429/// The `ctx` argument lets us also catch the HLSL idiom of using a
430/// bool-typed *variable* in arithmetic context (corpus shape:
431/// `bool mask = …; return !mask*domain + mask*refrac_uv;`). HLSL
432/// silently coerces those bool operands to 0.0 / 1.0; WGSL refuses
433/// the multiplication outright, so the wrap has to apply at the
434/// identifier level too.
435fn is_boolean_producing(e: &Expr, ctx: &mut WalkCtx) -> bool {
436    match e {
437        Expr::Binary(b) => matches!(
438            b.op,
439            BinaryOp::Eq
440                | BinaryOp::Ne
441                | BinaryOp::Lt
442                | BinaryOp::Le
443                | BinaryOp::Gt
444                | BinaryOp::Ge
445                | BinaryOp::And
446                | BinaryOp::Or
447        ),
448        Expr::Unary(u) => matches!(u.op, UnaryOp::Not) && is_boolean_producing(&u.operand, ctx),
449        Expr::Swizzle(s) => is_boolean_producing(&s.base, ctx),
450        Expr::Ident(name, _) => matches!(ctx.lookup(name), WgslType::Bool),
451        _ => false,
452    }
453}