onedrop_hlsl/rewrite/
scalar_swizzle.rs

1//! Pass: scalar `.xxx` swizzle broadcast → vec constructor.
2
3use super::*;
4
5// ---------------------------------------------------------------------------
6// Pass: scalar `.xxx`/`.xx`/`.x` swizzle broadcast → `floatN(x)`
7// ---------------------------------------------------------------------------
8
9/// HLSL allows `<scalar>.xxx` to broadcast a scalar to a `float3` (and `.xx`
10/// to `float2`, etc.). WGSL has no swizzle on scalars — naga rejects with
11/// `invalid field accessor 'xxx'` or stranger `expected ')'; found 'x3'`
12/// tokenisation. Rewrite to a constructor call `floatN(<scalar>)`; the
13/// regex `replace_types` pass then lowers the `float3` to `vec3<f32>`.
14pub(crate) fn rewrite_scalar_swizzle(src: &str) -> String {
15    let Ok(tu) = parse_hlsl(src) else {
16        return src.to_string();
17    };
18    let mut ctx = WalkCtx::new(src);
19    ctx.seed_globals(&tu);
20    if let Some(body) = &tu.shader_body {
21        walk_block_for_scalar_swizzle(body, &mut ctx);
22    }
23    for item in &tu.items {
24        if let Item::Function(f) = item {
25            ctx.scope_push();
26            for p in &f.params {
27                ctx.declare(&p.name, type_from_typeref(&p.ty));
28            }
29            walk_block_for_scalar_swizzle(&f.body, &mut ctx);
30            ctx.scope_pop();
31        }
32    }
33    apply_edits(src, &mut ctx.edits)
34}
35
36fn walk_block_for_scalar_swizzle(b: &Block, ctx: &mut WalkCtx) {
37    ctx.scope_push();
38    for s in &b.stmts {
39        walk_stmt_for_scalar_swizzle(s, ctx);
40    }
41    ctx.scope_pop();
42}
43
44fn walk_stmt_for_scalar_swizzle(s: &Stmt, ctx: &mut WalkCtx) {
45    match s {
46        Stmt::LocalDecl(d) => {
47            ctx.declare(&d.name, type_from_typeref(&d.ty));
48            if let Some(init) = &d.init {
49                walk_expr_for_scalar_swizzle(init, ctx);
50            }
51        }
52        Stmt::Assign(a) => {
53            walk_expr_for_scalar_swizzle(&a.target, ctx);
54            walk_expr_for_scalar_swizzle(&a.value, ctx);
55        }
56        Stmt::Expr(e) => {
57            walk_expr_for_scalar_swizzle(e, ctx);
58        }
59        Stmt::If(i) => {
60            walk_expr_for_scalar_swizzle(&i.cond, ctx);
61            walk_stmt_for_scalar_swizzle(&i.then_branch, ctx);
62            if let Some(e) = &i.else_branch {
63                walk_stmt_for_scalar_swizzle(e, ctx);
64            }
65        }
66        Stmt::While(w) => {
67            walk_expr_for_scalar_swizzle(&w.cond, ctx);
68            walk_stmt_for_scalar_swizzle(&w.body, ctx);
69        }
70        Stmt::For(f) => {
71            ctx.scope_push();
72            if let Some(init) = &f.init {
73                walk_stmt_for_scalar_swizzle(init, ctx);
74            }
75            if let Some(c) = &f.cond {
76                walk_expr_for_scalar_swizzle(c, ctx);
77            }
78            if let Some(st) = &f.step {
79                walk_expr_for_scalar_swizzle(st, ctx);
80            }
81            walk_stmt_for_scalar_swizzle(&f.body, ctx);
82            ctx.scope_pop();
83        }
84        Stmt::Return(Some(e)) => {
85            walk_expr_for_scalar_swizzle(e, ctx);
86        }
87        Stmt::Block(b) => walk_block_for_scalar_swizzle(b, ctx),
88        Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
89    }
90}
91
92fn walk_expr_for_scalar_swizzle(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
93    match e {
94        Expr::Swizzle(s) => {
95            let base_ty = walk_expr_for_scalar_swizzle(&s.base, ctx);
96            if matches!(base_ty, WgslType::F32 | WgslType::I32)
97                && !s.components.is_empty()
98                && s.components.chars().all(|c| c == 'x')
99            {
100                let n = s.components.len();
101                let head = match n {
102                    1 => {
103                        // `<scalar>.x` is identity — drop the suffix entirely
104                        // (more conservative than wrapping in `float(x)`).
105                        ctx.edits.push(TextEdit {
106                            start: s.base.span().end,
107                            end: s.span.end,
108                            replacement: String::new(),
109                        });
110                        return base_ty;
111                    }
112                    2 => "float2",
113                    3 => "float3",
114                    4 => "float4",
115                    _ => return base_ty,
116                };
117                let base_text = &ctx.src[s.base.span().start as usize..s.base.span().end as usize];
118                ctx.edits.push(TextEdit {
119                    start: s.span.start,
120                    end: s.span.end,
121                    replacement: format!("{head}({base_text})"),
122                });
123                return vec_of_size(n);
124            }
125            if base_ty.is_vec() {
126                vec_of_size(s.components.len())
127            } else {
128                WgslType::Unknown
129            }
130        }
131        Expr::Ident(name, _) => ctx.lookup(name),
132        Expr::Lit(l) => match l.value {
133            LitValue::Int(_) | LitValue::Float(_) => WgslType::F32,
134            LitValue::Bool(_) => WgslType::Bool,
135        },
136        Expr::Binary(b) => {
137            let lt = walk_expr_for_scalar_swizzle(&b.lhs, ctx);
138            let rt = walk_expr_for_scalar_swizzle(&b.rhs, ctx);
139            widen_type(lt, rt)
140        }
141        Expr::Unary(u) => walk_expr_for_scalar_swizzle(&u.operand, ctx),
142        Expr::Ternary(t) => {
143            walk_expr_for_scalar_swizzle(&t.cond, ctx);
144            let a = walk_expr_for_scalar_swizzle(&t.then_expr, ctx);
145            let b = walk_expr_for_scalar_swizzle(&t.else_expr, ctx);
146            widen_type(a, b)
147        }
148        Expr::Call(c) => {
149            if let Some(t) = constructor_return(&c.callee) {
150                for a in &c.args {
151                    walk_expr_for_scalar_swizzle(a, ctx);
152                }
153                return t;
154            }
155            for a in &c.args {
156                walk_expr_for_scalar_swizzle(a, ctx);
157            }
158            builtin_return(&c.callee, &c.args, ctx)
159        }
160        Expr::Member(m) => {
161            walk_expr_for_scalar_swizzle(&m.base, ctx);
162            WgslType::Unknown
163        }
164        Expr::Index(i) => {
165            walk_expr_for_scalar_swizzle(&i.base, ctx);
166            walk_expr_for_scalar_swizzle(&i.index, ctx);
167            WgslType::Unknown
168        }
169        Expr::InitList(l) => {
170            for e in &l.elems {
171                walk_expr_for_scalar_swizzle(e, ctx);
172            }
173            WgslType::Unknown
174        }
175        Expr::Assign(a) => {
176            walk_expr_for_scalar_swizzle(&a.target, ctx);
177            walk_expr_for_scalar_swizzle(&a.value, ctx)
178        }
179    }
180}