1use super::*;
4
5pub(crate) fn coerce_user_fn_args(src: &str) -> String {
26 let Ok(tu) = parse_hlsl(src) else {
27 return src.to_string();
28 };
29 let mut sigs: std::collections::HashMap<String, Vec<WgslType>> =
30 std::collections::HashMap::new();
31 for item in &tu.items {
32 if let Item::Function(f) = item {
33 let params: Vec<WgslType> = f.params.iter().map(|p| type_from_typeref(&p.ty)).collect();
34 sigs.insert(f.name.clone(), params);
35 }
36 }
37 if sigs.is_empty() {
38 return src.to_string();
39 }
40 let mut ctx = WalkCtx::new(src);
41 ctx.seed_globals(&tu);
42 if let Some(body) = &tu.shader_body {
43 walk_block_for_user_fn(body, &mut ctx, &sigs);
44 }
45 for item in &tu.items {
46 if let Item::Function(f) = item {
47 ctx.scope_push();
48 for p in &f.params {
49 ctx.declare(&p.name, type_from_typeref(&p.ty));
50 }
51 walk_block_for_user_fn(&f.body, &mut ctx, &sigs);
52 ctx.scope_pop();
53 }
54 }
55 apply_edits(src, &mut ctx.edits)
56}
57
58fn walk_block_for_user_fn(
59 b: &Block,
60 ctx: &mut WalkCtx,
61 sigs: &std::collections::HashMap<String, Vec<WgslType>>,
62) {
63 ctx.scope_push();
64 for s in &b.stmts {
65 walk_stmt_for_user_fn(s, ctx, sigs);
66 }
67 ctx.scope_pop();
68}
69
70fn walk_stmt_for_user_fn(
71 s: &Stmt,
72 ctx: &mut WalkCtx,
73 sigs: &std::collections::HashMap<String, Vec<WgslType>>,
74) {
75 match s {
76 Stmt::LocalDecl(d) => {
77 ctx.declare(&d.name, type_from_typeref(&d.ty));
78 if let Some(init) = &d.init {
79 walk_expr_for_user_fn(init, ctx, sigs);
80 }
81 }
82 Stmt::Assign(a) => {
83 walk_expr_for_user_fn(&a.target, ctx, sigs);
84 walk_expr_for_user_fn(&a.value, ctx, sigs);
85 }
86 Stmt::Expr(e) => {
87 walk_expr_for_user_fn(e, ctx, sigs);
88 }
89 Stmt::If(i) => {
90 walk_expr_for_user_fn(&i.cond, ctx, sigs);
91 walk_stmt_for_user_fn(&i.then_branch, ctx, sigs);
92 if let Some(e) = &i.else_branch {
93 walk_stmt_for_user_fn(e, ctx, sigs);
94 }
95 }
96 Stmt::While(w) => {
97 walk_expr_for_user_fn(&w.cond, ctx, sigs);
98 walk_stmt_for_user_fn(&w.body, ctx, sigs);
99 }
100 Stmt::For(f) => {
101 ctx.scope_push();
102 if let Some(init) = &f.init {
103 walk_stmt_for_user_fn(init, ctx, sigs);
104 }
105 if let Some(c) = &f.cond {
106 walk_expr_for_user_fn(c, ctx, sigs);
107 }
108 if let Some(st) = &f.step {
109 walk_expr_for_user_fn(st, ctx, sigs);
110 }
111 walk_stmt_for_user_fn(&f.body, ctx, sigs);
112 ctx.scope_pop();
113 }
114 Stmt::Return(Some(e)) => {
115 walk_expr_for_user_fn(e, ctx, sigs);
116 }
117 Stmt::Block(b) => walk_block_for_user_fn(b, ctx, sigs),
118 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
119 }
120}
121
122fn walk_expr_for_user_fn(
123 e: &Expr,
124 ctx: &mut WalkCtx,
125 sigs: &std::collections::HashMap<String, Vec<WgslType>>,
126) -> WgslType {
127 match e {
128 Expr::Call(c) => {
129 for a in &c.args {
132 walk_expr_for_user_fn(a, ctx, sigs);
133 }
134 if let Some(params) = sigs.get(&c.callee) {
135 for (idx, arg) in c.args.iter().enumerate() {
136 if let Some(expected) = params.get(idx).copied() {
137 coerce_arg(arg, expected, ctx);
138 }
139 }
140 return WgslType::Unknown;
144 }
145 if let Some(t) = constructor_return(&c.callee) {
146 return t;
147 }
148 builtin_return(&c.callee, &c.args, ctx)
149 }
150 Expr::Binary(b) => {
151 walk_expr_for_user_fn(&b.lhs, ctx, sigs);
152 walk_expr_for_user_fn(&b.rhs, ctx, sigs);
153 widen_type(infer_type(&b.lhs, ctx), infer_type(&b.rhs, ctx))
154 }
155 Expr::Unary(u) => walk_expr_for_user_fn(&u.operand, ctx, sigs),
156 Expr::Ternary(t) => {
157 walk_expr_for_user_fn(&t.cond, ctx, sigs);
158 walk_expr_for_user_fn(&t.then_expr, ctx, sigs);
159 walk_expr_for_user_fn(&t.else_expr, ctx, sigs)
160 }
161 Expr::Swizzle(s) => {
162 walk_expr_for_user_fn(&s.base, ctx, sigs);
163 infer_type(e, ctx)
164 }
165 Expr::Member(m) => {
166 walk_expr_for_user_fn(&m.base, ctx, sigs);
167 WgslType::Unknown
168 }
169 Expr::Index(i) => {
170 walk_expr_for_user_fn(&i.base, ctx, sigs);
171 walk_expr_for_user_fn(&i.index, ctx, sigs);
172 WgslType::Unknown
173 }
174 Expr::InitList(l) => {
175 for e in &l.elems {
176 walk_expr_for_user_fn(e, ctx, sigs);
177 }
178 WgslType::Unknown
179 }
180 Expr::Assign(a) => {
181 walk_expr_for_user_fn(&a.target, ctx, sigs);
182 walk_expr_for_user_fn(&a.value, ctx, sigs)
183 }
184 Expr::Ident(name, _) => ctx.lookup(name),
185 Expr::Lit(_) => WgslType::F32,
186 }
187}
188
189pub(super) fn coerce_arg(arg: &Expr, expected: WgslType, ctx: &mut WalkCtx) {
192 let got = infer_type(arg, ctx);
193 coerce_arg_known(arg, expected, got, ctx);
194}
195
196pub(super) fn coerce_arg_known(arg: &Expr, expected: WgslType, got: WgslType, ctx: &mut WalkCtx) {
201 if got == expected || got == WgslType::Unknown || expected == WgslType::Unknown {
202 return;
203 }
204 let span = arg.span();
205 match (got, expected) {
206 (g, WgslType::F32) if g.is_vec() => ctx.emit_truncation(span, 1),
208 (WgslType::F32, e) if e.is_vec() => {
210 let prefix = format!("{}(", e.wgsl_name());
211 ctx.edits.push(TextEdit {
212 start: span.start,
213 end: span.start,
214 replacement: prefix,
215 });
216 ctx.edits.push(TextEdit {
217 start: span.end,
218 end: span.end,
219 replacement: ")".to_string(),
220 });
221 }
222 (g, e) if g.is_vec() && e.is_vec() && vec_size(g) > vec_size(e) => {
224 ctx.emit_truncation(span, vec_size(e));
225 }
226 (g, e) if g.is_vec() && e.is_vec() && vec_size(g) < vec_size(e) => {
228 let pad = vec_size(e) - vec_size(g);
229 let zeros: Vec<&str> = (0..pad).map(|_| "0.0").collect();
230 let suffix = format!(", {})", zeros.join(", "));
231 ctx.edits.push(TextEdit {
232 start: span.start,
233 end: span.start,
234 replacement: format!("{}(", e.wgsl_name()),
235 });
236 ctx.edits.push(TextEdit {
237 start: span.end,
238 end: span.end,
239 replacement: suffix,
240 });
241 }
242 _ => {}
243 }
244}