onedrop_hlsl/rewrite/
binop_vec.rs

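//! Pass: binary-operator vector-size mismatch fix-up.
//!
//! HLSL silently truncates the wider operand when an arithmetic operator
//! (`+ - * / %`) combines vectors of different sizes; WGSL rejects such
//! expressions. This pass infers expression types while walking every
//! statement and emits a truncation edit for the wider operand. It also
//! coerces local-variable initialisers and `return` values to their
//! declared types.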
2
3use super::user_fn;
4use super::*;
5
6pub(crate) fn rewrite_binary_vec_mismatches(src: &str) -> String {
7    let Ok(tu) = parse_hlsl(src) else {
8        return src.to_string();
9    };
10    let mut ctx = WalkCtx::new(src);
11    ctx.seed_globals(&tu);
12    if let Some(body) = &tu.shader_body {
13        walk_block(body, &mut ctx, WgslType::Unknown);
14    }
15    for item in &tu.items {
16        if let Item::Function(f) = item {
17            ctx.scope_push();
18            for p in &f.params {
19                ctx.declare(&p.name, type_from_typeref(&p.ty));
20            }
21            walk_block(&f.body, &mut ctx, type_from_typeref(&f.return_type));
22            ctx.scope_pop();
23        }
24    }
25    apply_edits(src, &mut ctx.edits)
26}
27fn walk_block(b: &Block, ctx: &mut WalkCtx, return_ty: WgslType) {
28    ctx.scope_push();
29    for s in &b.stmts {
30        walk_stmt(s, ctx, return_ty);
31    }
32    ctx.scope_pop();
33}
34
35fn walk_stmt(s: &Stmt, ctx: &mut WalkCtx, return_ty: WgslType) {
36    match s {
37        Stmt::LocalDecl(d) => {
38            let decl_ty = type_from_typeref(&d.ty);
39            ctx.declare(&d.name, decl_ty);
40            if let Some(init) = &d.init {
41                let init_ty = walk_expr(init, ctx);
42                // HLSL silently truncates/broadcasts when the init's type
43                // doesn't match the declared type (e.g. `float k2 = (20*uv)%2;`
44                // — RHS is vec2 but the variable is scalar). WGSL refuses
45                // with `the type of \`k2\` is expected to be \`f32\`, but
46                // got \`vec2<f32>\``. Reuse the user-fn arg coercion
47                // helper, passing the type we just inferred (which already
48                // accounts for any in-flight truncation edits emitted by
49                // the binop walker on `init`'s sub-expressions — re-running
50                // `infer_type` would over-coerce already-truncated shapes).
51                if d.array_len.is_none() {
52                    user_fn::coerce_arg_known(init, decl_ty, init_ty, ctx);
53                }
54            }
55            if let Some(len) = &d.array_len {
56                walk_expr(len, ctx);
57            }
58        }
59        Stmt::Assign(a) => {
60            walk_expr(&a.target, ctx);
61            walk_expr(&a.value, ctx);
62        }
63        Stmt::Expr(e) => {
64            walk_expr(e, ctx);
65        }
66        Stmt::If(i) => {
67            walk_expr(&i.cond, ctx);
68            walk_stmt(&i.then_branch, ctx, return_ty);
69            if let Some(e) = &i.else_branch {
70                walk_stmt(e, ctx, return_ty);
71            }
72        }
73        Stmt::While(w) => {
74            walk_expr(&w.cond, ctx);
75            walk_stmt(&w.body, ctx, return_ty);
76        }
77        Stmt::For(f) => {
78            ctx.scope_push();
79            if let Some(init) = &f.init {
80                walk_stmt(init, ctx, return_ty);
81            }
82            if let Some(c) = &f.cond {
83                walk_expr(c, ctx);
84            }
85            if let Some(st) = &f.step {
86                walk_expr(st, ctx);
87            }
88            walk_stmt(&f.body, ctx, return_ty);
89            ctx.scope_pop();
90        }
91        Stmt::Return(Some(e)) => {
92            let val_ty = walk_expr(e, ctx);
93            // Same trick as the LocalDecl init: when the function declares
94            // `float foo()` but its body returns `dots*dots` (vec3), HLSL
95            // silently truncates to `.x`; WGSL refuses with `InvalidReturnType`.
96            // Reuse the user-fn arg coercion with the type we just inferred —
97            // skip when return_ty is Unknown (e.g. inside `shader_body`'s
98            // top-level statements where there's no enclosing function).
99            if !matches!(return_ty, WgslType::Unknown) {
100                user_fn::coerce_arg_known(e, return_ty, val_ty, ctx);
101            }
102        }
103        Stmt::Block(b) => walk_block(b, ctx, return_ty),
104        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
105    }
106}
107
108/// Walk an expression, return its inferred type. Emit fix-up edits along
109/// the way for sub-expressions that look broken.
110fn walk_expr(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
111    match e {
112        Expr::Lit(l) => match l.value {
113            LitValue::Int(_) => WgslType::F32, // HLSL coerces int→float freely
114            LitValue::Float(_) => WgslType::F32,
115            LitValue::Bool(_) => WgslType::Bool,
116        },
117        Expr::Ident(name, _) => ctx.lookup(name),
118        Expr::Binary(b) => {
119            let lt = walk_expr(&b.lhs, ctx);
120            let rt = walk_expr(&b.rhs, ctx);
121            let arith = matches!(
122                b.op,
123                BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Rem
124            );
125            if arith && lt.is_vec() && rt.is_vec() {
126                let ls = vec_size(lt);
127                let rs = vec_size(rt);
128                if ls != rs {
129                    let min = ls.min(rs);
130                    if ls > min {
131                        ctx.emit_truncation(b.lhs.span(), min);
132                    } else {
133                        ctx.emit_truncation(b.rhs.span(), min);
134                    }
135                    return vec_of_size(min);
136                }
137            }
138            // Comparison / logical → bool, otherwise widen.
139            if matches!(
140                b.op,
141                BinaryOp::Eq
142                    | BinaryOp::Ne
143                    | BinaryOp::Lt
144                    | BinaryOp::Le
145                    | BinaryOp::Gt
146                    | BinaryOp::Ge
147                    | BinaryOp::And
148                    | BinaryOp::Or
149            ) {
150                return WgslType::Bool;
151            }
152            widen_type(lt, rt)
153        }
154        Expr::Unary(u) => walk_expr(&u.operand, ctx),
155        Expr::Ternary(t) => {
156            walk_expr(&t.cond, ctx);
157            let a = walk_expr(&t.then_expr, ctx);
158            let b = walk_expr(&t.else_expr, ctx);
159            widen_type(a, b)
160        }
161        Expr::Call(c) => {
162            // Constructor calls: `float2(…)`, `vec3<f32>(…)`, etc. Return
163            // the corresponding type directly so subsequent binops can use
164            // it without round-tripping through arg inference.
165            if let Some(t) = constructor_return(&c.callee) {
166                for a in &c.args {
167                    walk_expr(a, ctx);
168                }
169                return t;
170            }
171            // Known builtins.
172            for a in &c.args {
173                walk_expr(a, ctx);
174            }
175            builtin_return(&c.callee, &c.args, ctx)
176        }
177        Expr::Member(m) => {
178            walk_expr(&m.base, ctx);
179            WgslType::Unknown // unknown struct fields — we don't model
180        }
181        Expr::Swizzle(s) => {
182            let base = walk_expr(&s.base, ctx);
183            if base.is_vec() {
184                vec_of_size(s.components.len())
185            } else {
186                WgslType::Unknown
187            }
188        }
189        Expr::Index(i) => {
190            walk_expr(&i.base, ctx);
191            walk_expr(&i.index, ctx);
192            WgslType::Unknown
193        }
194        Expr::InitList(l) => {
195            for e in &l.elems {
196                walk_expr(e, ctx);
197            }
198            WgslType::Unknown
199        }
200        Expr::Assign(a) => {
201            walk_expr(&a.target, ctx);
202            walk_expr(&a.value, ctx)
203        }
204    }
205}