onedrop_hlsl/rewrite/
binop_vec.rs1use super::user_fn;
4use super::*;
5
6pub(crate) fn rewrite_binary_vec_mismatches(src: &str) -> String {
7 let Ok(tu) = parse_hlsl(src) else {
8 return src.to_string();
9 };
10 let mut ctx = WalkCtx::new(src);
11 ctx.seed_globals(&tu);
12 if let Some(body) = &tu.shader_body {
13 walk_block(body, &mut ctx, WgslType::Unknown);
14 }
15 for item in &tu.items {
16 if let Item::Function(f) = item {
17 ctx.scope_push();
18 for p in &f.params {
19 ctx.declare(&p.name, type_from_typeref(&p.ty));
20 }
21 walk_block(&f.body, &mut ctx, type_from_typeref(&f.return_type));
22 ctx.scope_pop();
23 }
24 }
25 apply_edits(src, &mut ctx.edits)
26}
27fn walk_block(b: &Block, ctx: &mut WalkCtx, return_ty: WgslType) {
28 ctx.scope_push();
29 for s in &b.stmts {
30 walk_stmt(s, ctx, return_ty);
31 }
32 ctx.scope_pop();
33}
34
35fn walk_stmt(s: &Stmt, ctx: &mut WalkCtx, return_ty: WgslType) {
36 match s {
37 Stmt::LocalDecl(d) => {
38 let decl_ty = type_from_typeref(&d.ty);
39 ctx.declare(&d.name, decl_ty);
40 if let Some(init) = &d.init {
41 let init_ty = walk_expr(init, ctx);
42 if d.array_len.is_none() {
52 user_fn::coerce_arg_known(init, decl_ty, init_ty, ctx);
53 }
54 }
55 if let Some(len) = &d.array_len {
56 walk_expr(len, ctx);
57 }
58 }
59 Stmt::Assign(a) => {
60 walk_expr(&a.target, ctx);
61 walk_expr(&a.value, ctx);
62 }
63 Stmt::Expr(e) => {
64 walk_expr(e, ctx);
65 }
66 Stmt::If(i) => {
67 walk_expr(&i.cond, ctx);
68 walk_stmt(&i.then_branch, ctx, return_ty);
69 if let Some(e) = &i.else_branch {
70 walk_stmt(e, ctx, return_ty);
71 }
72 }
73 Stmt::While(w) => {
74 walk_expr(&w.cond, ctx);
75 walk_stmt(&w.body, ctx, return_ty);
76 }
77 Stmt::For(f) => {
78 ctx.scope_push();
79 if let Some(init) = &f.init {
80 walk_stmt(init, ctx, return_ty);
81 }
82 if let Some(c) = &f.cond {
83 walk_expr(c, ctx);
84 }
85 if let Some(st) = &f.step {
86 walk_expr(st, ctx);
87 }
88 walk_stmt(&f.body, ctx, return_ty);
89 ctx.scope_pop();
90 }
91 Stmt::Return(Some(e)) => {
92 let val_ty = walk_expr(e, ctx);
93 if !matches!(return_ty, WgslType::Unknown) {
100 user_fn::coerce_arg_known(e, return_ty, val_ty, ctx);
101 }
102 }
103 Stmt::Block(b) => walk_block(b, ctx, return_ty),
104 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
105 }
106}
107
108fn walk_expr(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
111 match e {
112 Expr::Lit(l) => match l.value {
113 LitValue::Int(_) => WgslType::F32, LitValue::Float(_) => WgslType::F32,
115 LitValue::Bool(_) => WgslType::Bool,
116 },
117 Expr::Ident(name, _) => ctx.lookup(name),
118 Expr::Binary(b) => {
119 let lt = walk_expr(&b.lhs, ctx);
120 let rt = walk_expr(&b.rhs, ctx);
121 let arith = matches!(
122 b.op,
123 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Rem
124 );
125 if arith && lt.is_vec() && rt.is_vec() {
126 let ls = vec_size(lt);
127 let rs = vec_size(rt);
128 if ls != rs {
129 let min = ls.min(rs);
130 if ls > min {
131 ctx.emit_truncation(b.lhs.span(), min);
132 } else {
133 ctx.emit_truncation(b.rhs.span(), min);
134 }
135 return vec_of_size(min);
136 }
137 }
138 if matches!(
140 b.op,
141 BinaryOp::Eq
142 | BinaryOp::Ne
143 | BinaryOp::Lt
144 | BinaryOp::Le
145 | BinaryOp::Gt
146 | BinaryOp::Ge
147 | BinaryOp::And
148 | BinaryOp::Or
149 ) {
150 return WgslType::Bool;
151 }
152 widen_type(lt, rt)
153 }
154 Expr::Unary(u) => walk_expr(&u.operand, ctx),
155 Expr::Ternary(t) => {
156 walk_expr(&t.cond, ctx);
157 let a = walk_expr(&t.then_expr, ctx);
158 let b = walk_expr(&t.else_expr, ctx);
159 widen_type(a, b)
160 }
161 Expr::Call(c) => {
162 if let Some(t) = constructor_return(&c.callee) {
166 for a in &c.args {
167 walk_expr(a, ctx);
168 }
169 return t;
170 }
171 for a in &c.args {
173 walk_expr(a, ctx);
174 }
175 builtin_return(&c.callee, &c.args, ctx)
176 }
177 Expr::Member(m) => {
178 walk_expr(&m.base, ctx);
179 WgslType::Unknown }
181 Expr::Swizzle(s) => {
182 let base = walk_expr(&s.base, ctx);
183 if base.is_vec() {
184 vec_of_size(s.components.len())
185 } else {
186 WgslType::Unknown
187 }
188 }
189 Expr::Index(i) => {
190 walk_expr(&i.base, ctx);
191 walk_expr(&i.index, ctx);
192 WgslType::Unknown
193 }
194 Expr::InitList(l) => {
195 for e in &l.elems {
196 walk_expr(e, ctx);
197 }
198 WgslType::Unknown
199 }
200 Expr::Assign(a) => {
201 walk_expr(&a.target, ctx);
202 walk_expr(&a.value, ctx)
203 }
204 }
205}