onedrop_hlsl/rewrite/
modf_arity.rs1use super::*;
34use crate::lex::Span;
35use std::sync::atomic::{AtomicUsize, Ordering};
36
/// Process-wide counter that `unique_temp` increments to number the modf
/// result temporaries. It is never reset, so it is shared by every rewrite
/// performed in the process.
static MODF_TEMP_COUNTER: AtomicUsize = AtomicUsize::new(0);
38
39pub(crate) fn rewrite_modf_arity(src: &str) -> String {
40 let Ok(tu) = parse_hlsl(src) else {
41 return src.to_string();
42 };
43 let mut edits = Vec::new();
44 if let Some(body) = &tu.shader_body {
45 walk_block(body, src, &mut edits);
46 }
47 for item in &tu.items {
48 if let Item::Function(f) = item {
49 walk_block(&f.body, src, &mut edits);
50 }
51 }
52 apply_edits(src, &mut edits)
53}
54
55fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
56 for s in &b.stmts {
57 walk_stmt(s, src, edits);
58 }
59}
60
61fn walk_stmt(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
62 match s {
63 Stmt::Assign(a) if matches!(a.op, AssignOp::Set) => {
64 try_emit_modf_rewrite(a, src, edits);
65 }
66 Stmt::LocalDecl(d) => {
67 if let Some(init) = &d.init {
68 try_emit_modf_in_init(d, init, src, edits);
69 }
70 }
71 Stmt::Block(b) => walk_block(b, src, edits),
72 Stmt::If(i) => {
73 walk_stmt(&i.then_branch, src, edits);
74 if let Some(e) = &i.else_branch {
75 walk_stmt(e, src, edits);
76 }
77 }
78 Stmt::While(w) => walk_stmt(&w.body, src, edits),
79 Stmt::For(f) => walk_stmt(&f.body, src, edits),
80 _ => {}
81 }
82}
83
84fn try_emit_modf_rewrite(a: &AssignStmt, src: &str, edits: &mut Vec<TextEdit>) {
85 let Expr::Call(c) = &a.value else { return };
86 if !c.callee.eq_ignore_ascii_case("modf") || c.args.len() != 2 {
87 return;
88 }
89 let Expr::Ident(_, _) = &c.args[1] else {
90 return;
91 };
92 let target_text = slice(src, a.target.span());
93 let out_text = slice(src, c.args[1].span());
94 let arg_text = slice(src, c.args[0].span());
95 let tmp = unique_temp();
96 let stmt_text = format!(
97 "let {tmp} = modf({arg_text}); {out_text} = {tmp}.whole; {target_text} = {tmp}.fract;"
98 );
99 edits.push(TextEdit {
100 start: a.span.start,
101 end: span_includes_semi(src, a.span).end,
102 replacement: stmt_text,
103 });
104}
105
106fn try_emit_modf_in_init(d: &LocalDecl, init: &Expr, src: &str, edits: &mut Vec<TextEdit>) {
107 let Expr::Call(c) = init else { return };
111 if !c.callee.eq_ignore_ascii_case("modf") || c.args.len() != 2 {
112 return;
113 }
114 let Expr::Ident(_, _) = &c.args[1] else {
115 return;
116 };
117 let ty_text = slice(src, d.ty.span);
118 let out_text = slice(src, c.args[1].span());
119 let arg_text = slice(src, c.args[0].span());
120 let tmp = unique_temp();
121 let replacement = format!(
122 "let {tmp} = modf({arg_text}); {out_text} = {tmp}.whole; {ty} {name} = {tmp}.fract;",
123 ty = ty_text,
124 name = d.name,
125 );
126 let init_end = init.span().end;
127 let stmt_end = span_includes_semi(
128 src,
129 Span {
130 start: d.span.start,
131 end: init_end,
132 line: d.span.line,
133 col: d.span.col,
134 },
135 )
136 .end;
137 edits.push(TextEdit {
138 start: d.span.start,
139 end: stmt_end,
140 replacement,
141 });
142}
143
144fn unique_temp() -> String {
145 let n = MODF_TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
146 format!("_md2_modf_{n}")
147}
148
149fn slice(src: &str, span: Span) -> &str {
150 &src[span.start as usize..span.end as usize]
151}
152
153fn span_includes_semi(src: &str, span: Span) -> Span {
154 let end = span.end as usize;
155 let bytes = src.as_bytes();
156 let mut i = end;
157 while i < bytes.len() && bytes[i].is_ascii_whitespace() {
158 i += 1;
159 }
160 if i < bytes.len() && bytes[i] == b';' {
161 Span {
162 end: (i + 1) as u32,
163 ..span
164 }
165 } else {
166 span
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::translate_shader;

    /// Removes all whitespace, so that checks for a stale or lowered call do
    /// not depend on how the rewriter or translator spaces its output.
    fn squash(text: &str) -> String {
        text.chars().filter(|c| !c.is_whitespace()).collect()
    }

    #[test]
    fn two_arg_modf_in_assignment_gets_lowered() {
        let src = r#"shader_body {
    float fg;
    float fb;
    fg = modf((1-fg)*255.0, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert!(
            out.contains("modf((1-fg)*255.0)"),
            "single-arg call missing: {out}"
        );
        assert!(out.contains(".whole"), "whole field missing: {out}");
        assert!(out.contains(".fract"), "fract field missing: {out}");
        assert!(
            !squash(&out).contains("modf((1-fg)*255.0,fb)"),
            "stale two-arg call: {out}"
        );
    }

    #[test]
    fn one_arg_modf_left_alone() {
        let src = r#"shader_body { float r = modf(0.5); }"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn out_arg_must_be_ident() {
        let src = r#"shader_body {
    float fg;
    float2 v;
    fg = modf(0.5, v.x);
}"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn translate_roundtrip_putdist_pattern() {
        let hlsl = r#"
float2 PutDist(float x) {
    float fg = 0.0;
    float fb = 0.0;
    fg = modf((1-x)*255.0, fb);
    return float2(fg, fb);
}
shader_body { ret = float3(PutDist(0.5), 0); }
"#;
        let wgsl = translate_shader(hlsl).expect("translates");
        let compact = squash(&wgsl);
        assert!(
            compact.contains("modf((1-x)*255.0)"),
            "expected single-arg modf, got:\n{wgsl}"
        );
        assert!(
            wgsl.contains(".whole") && wgsl.contains(".fract"),
            "struct field access missing:\n{wgsl}"
        );
        // Compare without whitespace: the source spells the call `, fb`, so a
        // space-sensitive check would miss a passed-through two-arg call.
        assert!(
            !compact.contains("modf((1-x)*255.0,fb)"),
            "stale two-arg call: {wgsl}"
        );
    }

    #[test]
    fn compound_assign_with_modf_left_alone() {
        let src = r#"shader_body {
    float fg;
    float fb;
    fg += modf(0.5, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn modf_in_local_decl_init() {
        let src = r#"shader_body {
    float fb;
    float fg = modf(0.5, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert!(
            out.contains(".whole") && out.contains(".fract"),
            "got: {out}"
        );
        assert!(
            !squash(&out).contains("modf(0.5,fb)"),
            "stale two-arg: {out}"
        );
    }
}