onedrop_hlsl/rewrite/
modf_arity.rs

//! Pass: lower HLSL two-arg `modf(x, out)` to WGSL's one-arg struct form.
//!
//! HLSL spec: `modf(x, out ip)` returns the fractional part of `x` and
//! writes the integer part (truncated toward zero) into the `out` parameter `ip`.
5//!
6//! WGSL spec: `modf(x)` returns `__modf_result_f32 { fract: f32, whole:
7//! f32 }`. The two-arg HLSL form is rejected at parse with
8//! `too many arguments passed to 'modf'`.
9//!
10//! Corpus shape (5 / 2 000 warp presets, all in the same `PutDist`
11//! helper inherited from the whoah / martin / amandio packs):
12//!
13//! ```hlsl
14//! float fg, fb;
15//! fg = modf((1-x)*255.0, fb);
16//! ```
17//!
18//! Rewrite injects a unique temp at the call site so the input is only
19//! evaluated once, then replaces the original assign with the two
20//! component reads:
21//!
22//! ```wgsl
23//! let _md2_modf_K = modf((1-x)*255.0);
24//! fb = _md2_modf_K.whole;
25//! fg = _md2_modf_K.fract;
26//! ```
27//!
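//! Only fires on `Stmt::Assign(target, =, Call("modf", [arg, out_ident]))`
//! and on local-decl initialisers of the same shape (`float fg = modf(x, fb);`),
//! where `out_ident` is a bare `Expr::Ident`; the callee name is compared
//! case-insensitively. Compound assigns (`+=` etc.) and out-args written with a
//! C-style `&` prefix (`modf(x, &out)`, which is not HLSL syntax) are
//! out-of-scope — neither appears in the corpus. Non-identifier out-args such
//! as `modf(x, v.x)` are likewise left untouched.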
32
33use super::*;
34use crate::lex::Span;
35use std::sync::atomic::{AtomicUsize, Ordering};
36
37static MODF_TEMP_COUNTER: AtomicUsize = AtomicUsize::new(0);
38
39pub(crate) fn rewrite_modf_arity(src: &str) -> String {
40    let Ok(tu) = parse_hlsl(src) else {
41        return src.to_string();
42    };
43    let mut edits = Vec::new();
44    if let Some(body) = &tu.shader_body {
45        walk_block(body, src, &mut edits);
46    }
47    for item in &tu.items {
48        if let Item::Function(f) = item {
49            walk_block(&f.body, src, &mut edits);
50        }
51    }
52    apply_edits(src, &mut edits)
53}
54
55fn walk_block(b: &Block, src: &str, edits: &mut Vec<TextEdit>) {
56    for s in &b.stmts {
57        walk_stmt(s, src, edits);
58    }
59}
60
61fn walk_stmt(s: &Stmt, src: &str, edits: &mut Vec<TextEdit>) {
62    match s {
63        Stmt::Assign(a) if matches!(a.op, AssignOp::Set) => {
64            try_emit_modf_rewrite(a, src, edits);
65        }
66        Stmt::LocalDecl(d) => {
67            if let Some(init) = &d.init {
68                try_emit_modf_in_init(d, init, src, edits);
69            }
70        }
71        Stmt::Block(b) => walk_block(b, src, edits),
72        Stmt::If(i) => {
73            walk_stmt(&i.then_branch, src, edits);
74            if let Some(e) = &i.else_branch {
75                walk_stmt(e, src, edits);
76            }
77        }
78        Stmt::While(w) => walk_stmt(&w.body, src, edits),
79        Stmt::For(f) => walk_stmt(&f.body, src, edits),
80        _ => {}
81    }
82}
83
84fn try_emit_modf_rewrite(a: &AssignStmt, src: &str, edits: &mut Vec<TextEdit>) {
85    let Expr::Call(c) = &a.value else { return };
86    if !c.callee.eq_ignore_ascii_case("modf") || c.args.len() != 2 {
87        return;
88    }
89    let Expr::Ident(_, _) = &c.args[1] else {
90        return;
91    };
92    let target_text = slice(src, a.target.span());
93    let out_text = slice(src, c.args[1].span());
94    let arg_text = slice(src, c.args[0].span());
95    let tmp = unique_temp();
96    let stmt_text = format!(
97        "let {tmp} = modf({arg_text}); {out_text} = {tmp}.whole; {target_text} = {tmp}.fract;"
98    );
99    edits.push(TextEdit {
100        start: a.span.start,
101        end: span_includes_semi(src, a.span).end,
102        replacement: stmt_text,
103    });
104}
105
106fn try_emit_modf_in_init(d: &LocalDecl, init: &Expr, src: &str, edits: &mut Vec<TextEdit>) {
107    // `float fg = modf(x, fb);` — same shape but inside a local-decl init.
108    // The LocalDecl's `span` only covers `T name` (not the init), so the
109    // replacement range has to extend to the `;` after the init expr.
110    let Expr::Call(c) = init else { return };
111    if !c.callee.eq_ignore_ascii_case("modf") || c.args.len() != 2 {
112        return;
113    }
114    let Expr::Ident(_, _) = &c.args[1] else {
115        return;
116    };
117    let ty_text = slice(src, d.ty.span);
118    let out_text = slice(src, c.args[1].span());
119    let arg_text = slice(src, c.args[0].span());
120    let tmp = unique_temp();
121    let replacement = format!(
122        "let {tmp} = modf({arg_text}); {out_text} = {tmp}.whole; {ty} {name} = {tmp}.fract;",
123        ty = ty_text,
124        name = d.name,
125    );
126    let init_end = init.span().end;
127    let stmt_end = span_includes_semi(
128        src,
129        Span {
130            start: d.span.start,
131            end: init_end,
132            line: d.span.line,
133            col: d.span.col,
134        },
135    )
136    .end;
137    edits.push(TextEdit {
138        start: d.span.start,
139        end: stmt_end,
140        replacement,
141    });
142}
143
144fn unique_temp() -> String {
145    let n = MODF_TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
146    format!("_md2_modf_{n}")
147}
148
149fn slice(src: &str, span: Span) -> &str {
150    &src[span.start as usize..span.end as usize]
151}
152
153fn span_includes_semi(src: &str, span: Span) -> Span {
154    let end = span.end as usize;
155    let bytes = src.as_bytes();
156    let mut i = end;
157    while i < bytes.len() && bytes[i].is_ascii_whitespace() {
158        i += 1;
159    }
160    if i < bytes.len() && bytes[i] == b';' {
161        Span {
162            end: (i + 1) as u32,
163            ..span
164        }
165    } else {
166        span
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::translate_shader;

    #[test]
    fn two_arg_modf_in_assignment_gets_lowered() {
        let src = r#"shader_body {
    float fg;
    float fb;
    fg = modf((1-fg)*255.0, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert!(
            out.contains("modf((1-fg)*255.0)"),
            "single-arg call missing: {out}"
        );
        assert!(out.contains(".whole"), "whole field missing: {out}");
        assert!(out.contains(".fract"), "fract field missing: {out}");
        // Both destinations are fed from the shared temp.
        assert!(out.contains("fb = _md2_modf_"), "out-arg write missing: {out}");
        assert!(out.contains("fg = _md2_modf_"), "target write missing: {out}");
        assert!(
            !out.contains("modf((1-fg)*255.0, fb)"),
            "stale two-arg call: {out}"
        );
    }

    #[test]
    fn callee_match_is_case_insensitive() {
        let src = r#"shader_body {
    float fg;
    float fb;
    fg = MODF(0.5, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert!(out.contains("modf(0.5)"), "single-arg call missing: {out}");
        assert!(!out.contains("MODF(0.5, fb)"), "stale two-arg call: {out}");
    }

    #[test]
    fn source_without_modf_is_untouched() {
        let src = r#"shader_body { ret = float3(1, 0, 0); }"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn one_arg_modf_left_alone() {
        // The WGSL-shape call survives — no two-arg shape to detect.
        let src = r#"shader_body { float r = modf(0.5); }"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn out_arg_must_be_ident() {
        // `modf(x, foo.bar)` is unusual and we don't try to lower it —
        // the existing source survives, naga will report the same error.
        let src = r#"shader_body {
    float fg;
    float2 v;
    fg = modf(0.5, v.x);
}"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn translate_roundtrip_putdist_pattern() {
        let hlsl = r#"
float2 PutDist(float x) {
    float fg = 0.0;
    float fb = 0.0;
    fg = modf((1-x)*255.0, fb);
    return float2(fg, fb);
}
shader_body { ret = float3(PutDist(0.5), 0); }
"#;
        let wgsl = translate_shader(hlsl).expect("translates");
        assert!(
            wgsl.contains("modf((1-x)*255.0)"),
            "expected single-arg modf, got:\n{wgsl}"
        );
        assert!(
            wgsl.contains(".whole") && wgsl.contains(".fract"),
            "struct field access missing:\n{wgsl}"
        );
        // The original two-arg form is gone, whatever spacing follows the
        // comma. Checking only one spelling could pass vacuously.
        for stale in ["modf((1-x)*255.0,fb)", "modf((1-x)*255.0, fb)"] {
            assert!(
                !wgsl.contains(stale),
                "stale two-arg call `{stale}`:\n{wgsl}"
            );
        }
    }

    #[test]
    fn compound_assign_with_modf_left_alone() {
        // `fg += modf(x, fb);` — out-of-scope; the corpus uses `=` only.
        let src = r#"shader_body {
    float fg;
    float fb;
    fg += modf(0.5, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert_eq!(out, src);
    }

    #[test]
    fn modf_in_local_decl_init() {
        let src = r#"shader_body {
    float fb;
    float fg = modf(0.5, fb);
}"#;
        let out = rewrite_modf_arity(src);
        assert!(
            out.contains(".whole") && out.contains(".fract"),
            "got: {out}"
        );
        assert!(out.contains("modf(0.5)"), "single-arg call missing: {out}");
        assert!(out.contains("fb = _md2_modf_"), "out-arg write missing: {out}");
        assert!(!out.contains("modf(0.5, fb)"), "stale two-arg: {out}");
    }
}