onedrop_hlsl/rewrite/
user_fn.rs

1//! Pass: user-defined function call-site arg coercion.
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass: user-fn call-site arg coercion
7// ---------------------------------------------------------------------------
8
9/// Collect every top-level user-defined function in `src`, then walk every
10/// call site to those names and coerce mismatched arg types. MD2 author
11/// sloppiness picked up here: passing `lavcol(ret*2)` where `ret*2` is
12/// vec3 but `lavcol(float t)` expects scalar (the EVET preset). HLSL
13/// silently truncated; WGSL refuses with InvalidCall.
14///
15/// Coercion rules per arg:
16/// - vec_n → f32: truncate via `.x` (single-component swizzle).
17/// - f32 → vec_n: broadcast via `vec_n<f32>(arg)`.
18/// - vec_n → vec_m where n > m: truncate via `.xy`/`.xyz`.
19/// - vec_n → vec_m where n < m: pad via `vec_m<f32>(arg, 0.0, …)`.
20/// - any other mismatch (Unknown, mat*, bool): pass through.
21///
22/// Conservative: only fires when the param type is a known scalar/vec
23/// and the inferred arg type is a known scalar/vec. Anything ambiguous
24/// stays untouched.
25pub(crate) fn coerce_user_fn_args(src: &str) -> String {
26    let Ok(tu) = parse_hlsl(src) else {
27        return src.to_string();
28    };
29    let mut sigs: std::collections::HashMap<String, Vec<WgslType>> =
30        std::collections::HashMap::new();
31    for item in &tu.items {
32        if let Item::Function(f) = item {
33            let params: Vec<WgslType> = f.params.iter().map(|p| type_from_typeref(&p.ty)).collect();
34            sigs.insert(f.name.clone(), params);
35        }
36    }
37    if sigs.is_empty() {
38        return src.to_string();
39    }
40    let mut ctx = WalkCtx::new(src);
41    ctx.seed_globals(&tu);
42    if let Some(body) = &tu.shader_body {
43        walk_block_for_user_fn(body, &mut ctx, &sigs);
44    }
45    for item in &tu.items {
46        if let Item::Function(f) = item {
47            ctx.scope_push();
48            for p in &f.params {
49                ctx.declare(&p.name, type_from_typeref(&p.ty));
50            }
51            walk_block_for_user_fn(&f.body, &mut ctx, &sigs);
52            ctx.scope_pop();
53        }
54    }
55    apply_edits(src, &mut ctx.edits)
56}
57
58fn walk_block_for_user_fn(
59    b: &Block,
60    ctx: &mut WalkCtx,
61    sigs: &std::collections::HashMap<String, Vec<WgslType>>,
62) {
63    ctx.scope_push();
64    for s in &b.stmts {
65        walk_stmt_for_user_fn(s, ctx, sigs);
66    }
67    ctx.scope_pop();
68}
69
70fn walk_stmt_for_user_fn(
71    s: &Stmt,
72    ctx: &mut WalkCtx,
73    sigs: &std::collections::HashMap<String, Vec<WgslType>>,
74) {
75    match s {
76        Stmt::LocalDecl(d) => {
77            ctx.declare(&d.name, type_from_typeref(&d.ty));
78            if let Some(init) = &d.init {
79                walk_expr_for_user_fn(init, ctx, sigs);
80            }
81        }
82        Stmt::Assign(a) => {
83            walk_expr_for_user_fn(&a.target, ctx, sigs);
84            walk_expr_for_user_fn(&a.value, ctx, sigs);
85        }
86        Stmt::Expr(e) => {
87            walk_expr_for_user_fn(e, ctx, sigs);
88        }
89        Stmt::If(i) => {
90            walk_expr_for_user_fn(&i.cond, ctx, sigs);
91            walk_stmt_for_user_fn(&i.then_branch, ctx, sigs);
92            if let Some(e) = &i.else_branch {
93                walk_stmt_for_user_fn(e, ctx, sigs);
94            }
95        }
96        Stmt::While(w) => {
97            walk_expr_for_user_fn(&w.cond, ctx, sigs);
98            walk_stmt_for_user_fn(&w.body, ctx, sigs);
99        }
100        Stmt::For(f) => {
101            ctx.scope_push();
102            if let Some(init) = &f.init {
103                walk_stmt_for_user_fn(init, ctx, sigs);
104            }
105            if let Some(c) = &f.cond {
106                walk_expr_for_user_fn(c, ctx, sigs);
107            }
108            if let Some(st) = &f.step {
109                walk_expr_for_user_fn(st, ctx, sigs);
110            }
111            walk_stmt_for_user_fn(&f.body, ctx, sigs);
112            ctx.scope_pop();
113        }
114        Stmt::Return(Some(e)) => {
115            walk_expr_for_user_fn(e, ctx, sigs);
116        }
117        Stmt::Block(b) => walk_block_for_user_fn(b, ctx, sigs),
118        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
119    }
120}
121
122fn walk_expr_for_user_fn(
123    e: &Expr,
124    ctx: &mut WalkCtx,
125    sigs: &std::collections::HashMap<String, Vec<WgslType>>,
126) -> WgslType {
127    match e {
128        Expr::Call(c) => {
129            // Recurse into args first so nested user-fn calls + inner
130            // edits run before we add our own.
131            for a in &c.args {
132                walk_expr_for_user_fn(a, ctx, sigs);
133            }
134            if let Some(params) = sigs.get(&c.callee) {
135                for (idx, arg) in c.args.iter().enumerate() {
136                    if let Some(expected) = params.get(idx).copied() {
137                        coerce_arg(arg, expected, ctx);
138                    }
139                }
140                // Return type isn't tracked here — we don't need it for
141                // the call-site walker, just emit edits and return
142                // Unknown.
143                return WgslType::Unknown;
144            }
145            if let Some(t) = constructor_return(&c.callee) {
146                return t;
147            }
148            builtin_return(&c.callee, &c.args, ctx)
149        }
150        Expr::Binary(b) => {
151            walk_expr_for_user_fn(&b.lhs, ctx, sigs);
152            walk_expr_for_user_fn(&b.rhs, ctx, sigs);
153            widen_type(infer_type(&b.lhs, ctx), infer_type(&b.rhs, ctx))
154        }
155        Expr::Unary(u) => walk_expr_for_user_fn(&u.operand, ctx, sigs),
156        Expr::Ternary(t) => {
157            walk_expr_for_user_fn(&t.cond, ctx, sigs);
158            walk_expr_for_user_fn(&t.then_expr, ctx, sigs);
159            walk_expr_for_user_fn(&t.else_expr, ctx, sigs)
160        }
161        Expr::Swizzle(s) => {
162            walk_expr_for_user_fn(&s.base, ctx, sigs);
163            infer_type(e, ctx)
164        }
165        Expr::Member(m) => {
166            walk_expr_for_user_fn(&m.base, ctx, sigs);
167            WgslType::Unknown
168        }
169        Expr::Index(i) => {
170            walk_expr_for_user_fn(&i.base, ctx, sigs);
171            walk_expr_for_user_fn(&i.index, ctx, sigs);
172            WgslType::Unknown
173        }
174        Expr::InitList(l) => {
175            for e in &l.elems {
176                walk_expr_for_user_fn(e, ctx, sigs);
177            }
178            WgslType::Unknown
179        }
180        Expr::Assign(a) => {
181            walk_expr_for_user_fn(&a.target, ctx, sigs);
182            walk_expr_for_user_fn(&a.value, ctx, sigs)
183        }
184        Expr::Ident(name, _) => ctx.lookup(name),
185        Expr::Lit(_) => WgslType::F32,
186    }
187}
188
189/// Emit an edit on `arg` that converts its inferred type to `expected`.
190/// No-op when both types agree or either is unknown / non-scalar-vec.
191pub(super) fn coerce_arg(arg: &Expr, expected: WgslType, ctx: &mut WalkCtx) {
192    let got = infer_type(arg, ctx);
193    coerce_arg_known(arg, expected, got, ctx);
194}
195
196/// Variant that takes the already-known operand type. Use this when an
197/// outer pass (binop_vec) has already type-walked the expression and may
198/// have emitted truncation edits that change the *effective* type — the
199/// stateless `infer_type` doesn't see those edits and would over-coerce.
200pub(super) fn coerce_arg_known(arg: &Expr, expected: WgslType, got: WgslType, ctx: &mut WalkCtx) {
201    if got == expected || got == WgslType::Unknown || expected == WgslType::Unknown {
202        return;
203    }
204    let span = arg.span();
205    match (got, expected) {
206        // vec → f32: truncate with `.x`.
207        (g, WgslType::F32) if g.is_vec() => ctx.emit_truncation(span, 1),
208        // f32 → vec_n: broadcast via constructor.
209        (WgslType::F32, e) if e.is_vec() => {
210            let prefix = format!("{}(", e.wgsl_name());
211            ctx.edits.push(TextEdit {
212                start: span.start,
213                end: span.start,
214                replacement: prefix,
215            });
216            ctx.edits.push(TextEdit {
217                start: span.end,
218                end: span.end,
219                replacement: ")".to_string(),
220            });
221        }
222        // vec_n → vec_m with n > m: truncate.
223        (g, e) if g.is_vec() && e.is_vec() && vec_size(g) > vec_size(e) => {
224            ctx.emit_truncation(span, vec_size(e));
225        }
226        // vec_n → vec_m with n < m: pad with zeros via constructor.
227        (g, e) if g.is_vec() && e.is_vec() && vec_size(g) < vec_size(e) => {
228            let pad = vec_size(e) - vec_size(g);
229            let zeros: Vec<&str> = (0..pad).map(|_| "0.0").collect();
230            let suffix = format!(", {})", zeros.join(", "));
231            ctx.edits.push(TextEdit {
232                start: span.start,
233                end: span.start,
234                replacement: format!("{}(", e.wgsl_name()),
235            });
236            ctx.edits.push(TextEdit {
237                start: span.end,
238                end: span.end,
239                replacement: suffix,
240            });
241        }
242        _ => {}
243    }
244}