onedrop_hlsl/rewrite/
scalar_swizzle.rs1use super::*;
4
5pub(crate) fn rewrite_scalar_swizzle(src: &str) -> String {
15 let Ok(tu) = parse_hlsl(src) else {
16 return src.to_string();
17 };
18 let mut ctx = WalkCtx::new(src);
19 ctx.seed_globals(&tu);
20 if let Some(body) = &tu.shader_body {
21 walk_block_for_scalar_swizzle(body, &mut ctx);
22 }
23 for item in &tu.items {
24 if let Item::Function(f) = item {
25 ctx.scope_push();
26 for p in &f.params {
27 ctx.declare(&p.name, type_from_typeref(&p.ty));
28 }
29 walk_block_for_scalar_swizzle(&f.body, &mut ctx);
30 ctx.scope_pop();
31 }
32 }
33 apply_edits(src, &mut ctx.edits)
34}
35
36fn walk_block_for_scalar_swizzle(b: &Block, ctx: &mut WalkCtx) {
37 ctx.scope_push();
38 for s in &b.stmts {
39 walk_stmt_for_scalar_swizzle(s, ctx);
40 }
41 ctx.scope_pop();
42}
43
44fn walk_stmt_for_scalar_swizzle(s: &Stmt, ctx: &mut WalkCtx) {
45 match s {
46 Stmt::LocalDecl(d) => {
47 ctx.declare(&d.name, type_from_typeref(&d.ty));
48 if let Some(init) = &d.init {
49 walk_expr_for_scalar_swizzle(init, ctx);
50 }
51 }
52 Stmt::Assign(a) => {
53 walk_expr_for_scalar_swizzle(&a.target, ctx);
54 walk_expr_for_scalar_swizzle(&a.value, ctx);
55 }
56 Stmt::Expr(e) => {
57 walk_expr_for_scalar_swizzle(e, ctx);
58 }
59 Stmt::If(i) => {
60 walk_expr_for_scalar_swizzle(&i.cond, ctx);
61 walk_stmt_for_scalar_swizzle(&i.then_branch, ctx);
62 if let Some(e) = &i.else_branch {
63 walk_stmt_for_scalar_swizzle(e, ctx);
64 }
65 }
66 Stmt::While(w) => {
67 walk_expr_for_scalar_swizzle(&w.cond, ctx);
68 walk_stmt_for_scalar_swizzle(&w.body, ctx);
69 }
70 Stmt::For(f) => {
71 ctx.scope_push();
72 if let Some(init) = &f.init {
73 walk_stmt_for_scalar_swizzle(init, ctx);
74 }
75 if let Some(c) = &f.cond {
76 walk_expr_for_scalar_swizzle(c, ctx);
77 }
78 if let Some(st) = &f.step {
79 walk_expr_for_scalar_swizzle(st, ctx);
80 }
81 walk_stmt_for_scalar_swizzle(&f.body, ctx);
82 ctx.scope_pop();
83 }
84 Stmt::Return(Some(e)) => {
85 walk_expr_for_scalar_swizzle(e, ctx);
86 }
87 Stmt::Block(b) => walk_block_for_scalar_swizzle(b, ctx),
88 Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
89 }
90}
91
92fn walk_expr_for_scalar_swizzle(e: &Expr, ctx: &mut WalkCtx) -> WgslType {
93 match e {
94 Expr::Swizzle(s) => {
95 let base_ty = walk_expr_for_scalar_swizzle(&s.base, ctx);
96 if matches!(base_ty, WgslType::F32 | WgslType::I32)
97 && !s.components.is_empty()
98 && s.components.chars().all(|c| c == 'x')
99 {
100 let n = s.components.len();
101 let head = match n {
102 1 => {
103 ctx.edits.push(TextEdit {
106 start: s.base.span().end,
107 end: s.span.end,
108 replacement: String::new(),
109 });
110 return base_ty;
111 }
112 2 => "float2",
113 3 => "float3",
114 4 => "float4",
115 _ => return base_ty,
116 };
117 let base_text = &ctx.src[s.base.span().start as usize..s.base.span().end as usize];
118 ctx.edits.push(TextEdit {
119 start: s.span.start,
120 end: s.span.end,
121 replacement: format!("{head}({base_text})"),
122 });
123 return vec_of_size(n);
124 }
125 if base_ty.is_vec() {
126 vec_of_size(s.components.len())
127 } else {
128 WgslType::Unknown
129 }
130 }
131 Expr::Ident(name, _) => ctx.lookup(name),
132 Expr::Lit(l) => match l.value {
133 LitValue::Int(_) | LitValue::Float(_) => WgslType::F32,
134 LitValue::Bool(_) => WgslType::Bool,
135 },
136 Expr::Binary(b) => {
137 let lt = walk_expr_for_scalar_swizzle(&b.lhs, ctx);
138 let rt = walk_expr_for_scalar_swizzle(&b.rhs, ctx);
139 widen_type(lt, rt)
140 }
141 Expr::Unary(u) => walk_expr_for_scalar_swizzle(&u.operand, ctx),
142 Expr::Ternary(t) => {
143 walk_expr_for_scalar_swizzle(&t.cond, ctx);
144 let a = walk_expr_for_scalar_swizzle(&t.then_expr, ctx);
145 let b = walk_expr_for_scalar_swizzle(&t.else_expr, ctx);
146 widen_type(a, b)
147 }
148 Expr::Call(c) => {
149 if let Some(t) = constructor_return(&c.callee) {
150 for a in &c.args {
151 walk_expr_for_scalar_swizzle(a, ctx);
152 }
153 return t;
154 }
155 for a in &c.args {
156 walk_expr_for_scalar_swizzle(a, ctx);
157 }
158 builtin_return(&c.callee, &c.args, ctx)
159 }
160 Expr::Member(m) => {
161 walk_expr_for_scalar_swizzle(&m.base, ctx);
162 WgslType::Unknown
163 }
164 Expr::Index(i) => {
165 walk_expr_for_scalar_swizzle(&i.base, ctx);
166 walk_expr_for_scalar_swizzle(&i.index, ctx);
167 WgslType::Unknown
168 }
169 Expr::InitList(l) => {
170 for e in &l.elems {
171 walk_expr_for_scalar_swizzle(e, ctx);
172 }
173 WgslType::Unknown
174 }
175 Expr::Assign(a) => {
176 walk_expr_for_scalar_swizzle(&a.target, ctx);
177 walk_expr_for_scalar_swizzle(&a.value, ctx)
178 }
179 }
180}