onedrop_hlsl/rewrite/texture_uv.rs

//! Pass: texture-sampling UV coercion (tex2D / GetPixel / GetBlur*), plus `lum` / `pow` argument-width fixes.
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass 3: texture-sampling UV coercion
7// ---------------------------------------------------------------------------
8
9/// For every `tex2D(sampler, expr)` call where `expr` is inferred as a vec3
10/// or vec4 (not vec2), wrap `expr` in parens and append `.xy`. HLSL
11/// silently truncates non-vec2 UV args; WGSL rejects them. The regex
12/// pipeline doesn't reach into call args at this depth.
13///
14/// Also normalises `GetPixel`/`GetBlur1..3` arg #0 (which lift to
15/// `textureSample` later) and `pow` argument vec-mismatches (cluster A).
16pub(crate) fn coerce_texture_uv_args(src: &str) -> String {
17    let Ok(tu) = parse_hlsl(src) else {
18        return src.to_string();
19    };
20    let mut ctx = WalkCtx::new(src);
21    ctx.seed_globals(&tu);
22    if let Some(body) = &tu.shader_body {
23        walk_block_for_uv(body, &mut ctx);
24    }
25    for item in &tu.items {
26        if let Item::Function(f) = item {
27            ctx.scope_push();
28            for p in &f.params {
29                ctx.declare(&p.name, type_from_typeref(&p.ty));
30            }
31            walk_block_for_uv(&f.body, &mut ctx);
32            ctx.scope_pop();
33        }
34    }
35    apply_edits(src, &mut ctx.edits)
36}
37
38fn walk_block_for_uv(b: &Block, ctx: &mut WalkCtx) {
39    ctx.scope_push();
40    for s in &b.stmts {
41        walk_stmt_for_uv(s, ctx);
42    }
43    ctx.scope_pop();
44}
45
46fn walk_stmt_for_uv(s: &Stmt, ctx: &mut WalkCtx) {
47    match s {
48        Stmt::LocalDecl(d) => {
49            ctx.declare(&d.name, type_from_typeref(&d.ty));
50            if let Some(init) = &d.init {
51                walk_expr_for_uv(init, ctx);
52            }
53        }
54        Stmt::Assign(a) => {
55            walk_expr_for_uv(&a.target, ctx);
56            walk_expr_for_uv(&a.value, ctx);
57        }
58        Stmt::Expr(e) => {
59            walk_expr_for_uv(e, ctx);
60        }
61        Stmt::If(i) => {
62            walk_expr_for_uv(&i.cond, ctx);
63            walk_stmt_for_uv(&i.then_branch, ctx);
64            if let Some(e) = &i.else_branch {
65                walk_stmt_for_uv(e, ctx);
66            }
67        }
68        Stmt::While(w) => {
69            walk_expr_for_uv(&w.cond, ctx);
70            walk_stmt_for_uv(&w.body, ctx);
71        }
72        Stmt::For(f) => {
73            ctx.scope_push();
74            if let Some(init) = &f.init {
75                walk_stmt_for_uv(init, ctx);
76            }
77            if let Some(c) = &f.cond {
78                walk_expr_for_uv(c, ctx);
79            }
80            if let Some(st) = &f.step {
81                walk_expr_for_uv(st, ctx);
82            }
83            walk_stmt_for_uv(&f.body, ctx);
84            ctx.scope_pop();
85        }
86        Stmt::Return(Some(e)) => {
87            walk_expr_for_uv(e, ctx);
88        }
89        Stmt::Block(b) => walk_block_for_uv(b, ctx),
90        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
91    }
92}
93
94fn walk_expr_for_uv(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
95    match e {
96        Expr::Call(c) => {
97            // Recurse first so nested calls see their args walked too.
98            for a in &c.args {
99                walk_expr_for_uv(a, ctx);
100            }
101            match c.callee.as_str() {
102                "tex2D" => {
103                    // Arg 1 is the UV — must be vec2. Truncate vec3/vec4
104                    // down to `.xy`; broadcast scalar (f32 — typically a
105                    // bare `uv.x` field access where the preset author
106                    // meant a 2D lookup) to `vec2<f32>(arg)` so naga
107                    // doesn't trip with InvalidImageCoordinateType. The
108                    // scalar-broadcast path lights up the midgitstraights
109                    // / suksma-neck pair of MD2 packs that wrote
110                    // `tex2D(sampler_noise_hq, uv.x)` verbatim.
111                    if let Some(uv_arg) = c.args.get(1) {
112                        let t = infer_type(uv_arg, ctx);
113                        if t.is_vec() && vec_size(t) > 2 {
114                            ctx.emit_truncation(uv_arg.span(), 2);
115                        } else if matches!(t, WgslType::F32 | WgslType::I32) {
116                            ctx.emit_scalar_to_vec2(uv_arg.span());
117                        }
118                    }
119                    WgslType::Vec4F
120                }
121                "GetPixel" | "GetBlur1" | "GetBlur2" | "GetBlur3" => {
122                    if let Some(uv_arg) = c.args.first() {
123                        let t = infer_type(uv_arg, ctx);
124                        if t.is_vec() && vec_size(t) > 2 {
125                            ctx.emit_truncation(uv_arg.span(), 2);
126                        } else if matches!(t, WgslType::F32 | WgslType::I32) {
127                            ctx.emit_scalar_to_vec2(uv_arg.span());
128                        }
129                    }
130                    WgslType::Vec3F
131                }
132                // `lum(c: vec3<f32>) -> f32` — the codegen helper expects
133                // vec3 exactly. Truncate vec4 args down to `.xyz` for the
134                // `lum(noise)` family. For the vec2-arg case found in the
135                // shifter + Isosceles edit presets (`lum(roam_sin.yx)`,
136                // MD2-accepted by implicit pad-with-zero): pad to vec3
137                // via the WGSL constructor `vec3<f32>(<arg>, 0.0)` so it
138                // satisfies the helper signature without changing the
139                // luminance value for the x/y components.
140                "lum" => {
141                    if let Some(arg) = c.args.first() {
142                        let t = infer_type(arg, ctx);
143                        if t.is_vec() {
144                            let sz = vec_size(t);
145                            if sz > 3 {
146                                ctx.emit_truncation(arg.span(), 3);
147                            } else if sz == 2 {
148                                ctx.emit_pad_vec2_to_vec3(arg.span());
149                            }
150                        }
151                    }
152                    WgslType::F32
153                }
154                "pow" => {
155                    // Both args must share size. If one is vec_n and the
156                    // other is vec_m with m > n, truncate the larger.
157                    let a = c
158                        .args
159                        .first()
160                        .map(|e| infer_type(e, ctx))
161                        .unwrap_or(WgslType::Unknown);
162                    let b = c
163                        .args
164                        .get(1)
165                        .map(|e| infer_type(e, ctx))
166                        .unwrap_or(WgslType::Unknown);
167                    if a.is_vec() && b.is_vec() {
168                        let as_ = vec_size(a);
169                        let bs = vec_size(b);
170                        if as_ != bs {
171                            let min = as_.min(bs);
172                            if as_ > min {
173                                ctx.emit_truncation(c.args[0].span(), min);
174                            } else {
175                                ctx.emit_truncation(c.args[1].span(), min);
176                            }
177                        }
178                    }
179                    a
180                }
181                _ => builtin_return(&c.callee, &c.args, ctx),
182            }
183        }
184        Expr::Binary(b) => {
185            walk_expr_for_uv(&b.lhs, ctx);
186            walk_expr_for_uv(&b.rhs, ctx);
187            widen_type(infer_type(&b.lhs, ctx), infer_type(&b.rhs, ctx))
188        }
189        Expr::Unary(u) => walk_expr_for_uv(&u.operand, ctx),
190        Expr::Ternary(t) => {
191            walk_expr_for_uv(&t.cond, ctx);
192            walk_expr_for_uv(&t.then_expr, ctx);
193            walk_expr_for_uv(&t.else_expr, ctx)
194        }
195        Expr::Swizzle(s) => {
196            walk_expr_for_uv(&s.base, ctx);
197            infer_type(e, ctx)
198        }
199        Expr::Member(m) => {
200            walk_expr_for_uv(&m.base, ctx);
201            WgslType::Unknown
202        }
203        Expr::Index(i) => {
204            walk_expr_for_uv(&i.base, ctx);
205            walk_expr_for_uv(&i.index, ctx);
206            WgslType::Unknown
207        }
208        Expr::InitList(l) => {
209            for e in &l.elems {
210                walk_expr_for_uv(e, ctx);
211            }
212            WgslType::Unknown
213        }
214        Expr::Assign(a) => {
215            walk_expr_for_uv(&a.target, ctx);
216            walk_expr_for_uv(&a.value, ctx)
217        }
218        Expr::Ident(name, _) => ctx.lookup(name),
219        Expr::Lit(_) => WgslType::F32,
220    }
221}