onedrop_hlsl/rewrite/
swizzle_assign.rs1use super::*;
24
25pub(crate) fn rewrite_swizzle_assigns(src: &str) -> String {
26 let Ok(tu) = parse_hlsl(src) else {
27 return src.to_string();
28 };
29 let mut state = RewriteState {
30 src,
31 edits: Vec::new(),
32 counter: 0,
33 };
34 if let Some(body) = &tu.shader_body {
35 walk_block(body, &mut state);
36 }
37 for item in &tu.items {
38 if let Item::Function(f) = item {
39 walk_block(&f.body, &mut state);
40 }
41 }
42 apply_edits(src, &mut state.edits)
43}
44
45struct RewriteState<'a> {
46 src: &'a str,
47 edits: Vec<TextEdit>,
48 counter: usize,
49}
50
51fn walk_block(b: &Block, st: &mut RewriteState) {
52 for s in &b.stmts {
53 walk_stmt(s, st);
54 }
55}
56
57fn walk_stmt(s: &Stmt, st: &mut RewriteState) {
58 match s {
59 Stmt::Assign(a) => try_emit(a, st),
60 Stmt::If(i) => {
61 walk_stmt(&i.then_branch, st);
62 if let Some(e) = &i.else_branch {
63 walk_stmt(e, st);
64 }
65 }
66 Stmt::While(w) => walk_stmt(&w.body, st),
67 Stmt::For(f) => {
68 if let Some(init) = &f.init {
69 walk_stmt(init, st);
70 }
71 walk_stmt(&f.body, st);
72 }
73 Stmt::Block(b) => walk_block(b, st),
74 Stmt::LocalDecl(_) | Stmt::Expr(_) | Stmt::Return(_) | Stmt::Break | Stmt::Continue => {}
75 }
76}
77
78fn try_emit(a: &AssignStmt, st: &mut RewriteState) {
79 let Expr::Swizzle(sw) = &a.target else {
80 return;
81 };
82 let comps = sw.components.as_str();
83 if comps.len() < 2 {
84 return; }
86 if !components_unique(comps) {
87 return; }
89 if followed_by_comma(st.src, a.span.end) {
90 return; }
92
93 let n = comps.len();
94 let temp_ty = match n {
95 2 => "float2",
96 3 => "float3",
97 4 => "float4",
98 _ => return,
99 };
100 let id = st.counter;
101 st.counter += 1;
102 let tmp = format!("_swztmp_{id}");
103 let op_str = assign_op_str(a.op);
104
105 let base_text = slice(st.src, sw.base.span());
109 let rhs_text = slice(st.src, a.value.span());
110
111 let mut out = String::new();
112 out.push_str(temp_ty);
113 out.push(' ');
114 out.push_str(&tmp);
115 out.push_str(" = (");
116 out.push_str(rhs_text);
117 out.push(')');
118 for (i, c) in comps.chars().enumerate() {
119 let lhs_comp = normalise_component(c);
120 let rhs_comp = match i {
121 0 => 'x',
122 1 => 'y',
123 2 => 'z',
124 3 => 'w',
125 _ => unreachable!(),
126 };
127 out.push_str("; ");
128 out.push_str(base_text);
129 out.push('.');
130 out.push(lhs_comp);
131 out.push_str(op_str);
132 out.push_str(&tmp);
133 out.push('.');
134 out.push(rhs_comp);
135 }
136
137 st.edits.push(TextEdit {
138 start: a.span.start,
139 end: a.span.end,
140 replacement: out,
141 });
142}
143
144fn components_unique(s: &str) -> bool {
145 let mut seen = [false; 8]; for c in s.chars() {
147 let idx = match normalise_component(c) {
148 'x' => 0,
149 'y' => 1,
150 'z' => 2,
151 'w' => 3,
152 _ => return false,
153 };
154 if seen[idx] {
155 return false;
156 }
157 seen[idx] = true;
158 }
159 true
160}
161
/// Maps a colour-set swizzle letter (`r`, `g`, `b`, `a`) to its positional
/// counterpart (`x`, `y`, `z`, `w`); any other character is returned as is.
fn normalise_component(c: char) -> char {
    const COLOUR: [char; 4] = ['r', 'g', 'b', 'a'];
    const POSITION: [char; 4] = ['x', 'y', 'z', 'w'];
    COLOUR
        .iter()
        .position(|&k| k == c)
        .map_or(c, |i| POSITION[i])
}
171
172fn assign_op_str(op: AssignOp) -> &'static str {
173 match op {
174 AssignOp::Set => " = ",
175 AssignOp::Add => " += ",
176 AssignOp::Sub => " -= ",
177 AssignOp::Mul => " *= ",
178 AssignOp::Div => " /= ",
179 AssignOp::Rem => " %= ",
180 }
181}
182
/// Reports whether the first non-whitespace byte at or after offset `end`
/// in `src` is a `,`. Only space, tab, CR and LF are skipped; an offset at
/// or past the end of `src` yields `false`.
fn followed_by_comma(src: &str, end: u32) -> bool {
    let rest: &[u8] = src.as_bytes().get(end as usize..).unwrap_or(&[]);
    let next = rest
        .iter()
        .copied()
        .find(|&b| !matches!(b, b' ' | b'\t' | b'\r' | b'\n'));
    next == Some(b',')
}
195
196fn slice(src: &str, sp: Span) -> &str {
197 &src[sp.start as usize..sp.end as usize]
198}