onedrop_eval/evaluator/
rewriters.rs

1//! AST-style rewriters consumed by [`super::MilkEvaluator::preprocess_expression`].
2//!
3//! Each `rewrite_*` / `wrap_*` function takes an expression string and
4//! returns a normalised form `evalexpr` can parse. The passes are
5//! idempotent and are run in the order documented in
6//! [`super::preprocess`].
7
8use regex::Regex;
9use std::sync::LazyLock;
10
11use super::{
12    contains_top_level_byte, contains_top_level_comma, contains_top_level_comparison,
13    is_ident_byte, is_paren_func_call_open, match_close_paren, split_top_level_byte,
14    split_top_level_commas,
15};
16
17/// Rewrite `a = b = … = <expr>` into `<last> = <expr>; <prev> = <last>; …`
18/// so evalexpr — which makes `=` return `Empty` — doesn't choke on the
19/// outer assignments. Detected via a strict left-to-right scan: each
20/// step requires a bare identifier followed by `=` and another bare
21/// identifier followed by `=`, no operators between them. Real-world
22/// presets in the corpus: `dx=dx=(y*dx)*cos(time)*…` (the LHS is
23/// repeated — a common author typo, but evalexpr still has to accept).
24pub(super) fn rewrite_chain_assignments(s: &str) -> String {
25    static CHAIN_REGEX: LazyLock<Regex> = LazyLock::new(|| {
26        Regex::new(r"(?P<head>(?:^|[;\n])\s*)(?P<a>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*(?P<b>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*(?P<rest>[^;\n]+)").unwrap()
27    });
28    // Iterate to a fixed point — `a=b=c=expr` needs two passes.
29    let mut out = s.to_string();
30    for _ in 0..4 {
31        let new = CHAIN_REGEX
32            .replace_all(&out, "${head}${b} = ${rest}; ${a} = ${b}")
33            .to_string();
34        if new == out {
35            break;
36        }
37        out = new;
38    }
39    out
40}
41
42/// Rewrite `a & b` to `band(a, b)` and `a | b` to `bor(a, b)`, leaving
43/// `&&`, `||`, `&=`, `|=` untouched. MD2 expression syntax treats single
44/// `&` / `|` as numeric short-circuit AND/OR (returning 0.0/1.0), not as
45/// bitwise ops, so we map them onto our registered `band`/`bor` builtins
46/// which already preserve those semantics. Operates iteratively: each
47/// pass rewrites one operator from the left, then the next scan picks up
48/// any nested `&`/`|` that the previous edit's right operand contained.
49pub(super) fn rewrite_amp_pipe_to_band_bor(s: &str) -> String {
50    let mut out = s.to_string();
51    // Cap iterations defensively: even pathological inputs shouldn't
52    // contain more `&`/`|` operators than `out.len()`.
53    let cap = out.len() + 16;
54    for _ in 0..cap {
55        let Some((start, op_pos, end, op)) = find_amp_pipe_operands(&out) else {
56            break;
57        };
58        let left = out[start..op_pos].trim();
59        let right = out[op_pos + 1..end].trim();
60        let func = if op == '&' { "band" } else { "bor" };
61        let head = &out[..start];
62        let tail = &out[end..];
63        out = format!("{head}{func}({left}, {right}){tail}");
64    }
65    out
66}
67
68/// Locate the first `&` or `|` in `s` that is a single (non-`&&` / non-`||`,
69/// non-`&=` / non-`|=`) operator, plus the byte offsets of its left and
70/// right operand spans. Returns `None` if no eligible operator remains or
71/// either operand boundary can't be resolved.
72fn find_amp_pipe_operands(s: &str) -> Option<(usize, usize, usize, char)> {
73    let bytes = s.as_bytes();
74    let mut i = 0;
75    while i < bytes.len() {
76        let c = bytes[i];
77        if (c == b'&' || c == b'|')
78            && bytes.get(i + 1).copied() != Some(c)
79            && bytes.get(i + 1).copied() != Some(b'=')
80            && (i == 0 || bytes[i - 1] != c)
81            && let (Some(start), Some(end)) = (
82                find_left_operand_start(bytes, i),
83                find_right_operand_end(bytes, i),
84            )
85        {
86            return Some((start, i, end, c as char));
87        }
88        i += 1;
89    }
90    None
91}
92
/// Byte offset where the operand immediately left of `op_pos` begins.
///
/// Spaces and tabs between the operand and the operator are skipped. Two
/// operand shapes are recognised:
///
/// * a parenthesised group ending in `)` — the matching `(` is located by
///   depth counting, and any identifier characters directly before it are
///   included (so `foo(x)` is taken whole);
/// * a run of ASCII alphanumerics, `_` and `.` (identifier or number).
///
/// Returns `None` when nothing precedes the operator, the parentheses do
/// not balance, or the preceding byte fits neither shape.
fn find_left_operand_start(bytes: &[u8], op_pos: usize) -> Option<usize> {
    let is_word = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Index of the operand's final byte, ignoring blanks before the operator.
    let last = bytes[..op_pos]
        .iter()
        .rposition(|&b| b != b' ' && b != b'\t')?;

    match bytes[last] {
        b')' => {
            let mut depth = 1usize;
            let mut open = last;
            while open > 0 && depth > 0 {
                open -= 1;
                match bytes[open] {
                    b')' => depth += 1,
                    b'(' => depth -= 1,
                    _ => {}
                }
            }
            if depth != 0 {
                return None;
            }
            // Pull in a call name glued to the opening paren, if any.
            let name_len = bytes[..open]
                .iter()
                .rev()
                .take_while(|&&b| is_word(b))
                .count();
            Some(open - name_len)
        }
        c if is_word(c) || c == b'.' => {
            let run = bytes[..=last]
                .iter()
                .rev()
                .take_while(|&&b| is_word(b) || b == b'.')
                .count();
            Some(last + 1 - run)
        }
        _ => None,
    }
}
141
/// Byte offset one past the operand immediately right of `op_pos`.
///
/// The scan starts at `op_pos + 1`, skips spaces/tabs, optionally accepts a
/// single unary `-` / `+` (followed by more optional blanks), and then
/// recognises:
///
/// * a parenthesised group — ends just past its matching `)`;
/// * a run of ASCII alphanumerics, `_` and `.`; when that run is directly
///   followed by `(` the call's argument list is included as well.
///
/// Returns `None` when the input ends before an operand starts, when a
/// parenthesis is left unclosed, or when the first operand byte fits
/// neither shape.
fn find_right_operand_end(bytes: &[u8], op_pos: usize) -> Option<usize> {
    /// First index at or after `at` that is not a space or tab.
    fn skip_blanks(bytes: &[u8], mut at: usize) -> usize {
        while at < bytes.len() && matches!(bytes[at], b' ' | b'\t') {
            at += 1;
        }
        at
    }

    /// Given the index of a `(`, the index just past its matching `)`.
    fn past_matching_close(bytes: &[u8], open: usize) -> Option<usize> {
        let mut depth = 1usize;
        let mut at = open;
        while at + 1 < bytes.len() && depth > 0 {
            at += 1;
            match bytes[at] {
                b'(' => depth += 1,
                b')' => depth -= 1,
                _ => {}
            }
        }
        if depth == 0 {
            Some(at + 1)
        } else {
            None
        }
    }

    let is_atom = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'.';

    let mut first = skip_blanks(bytes, op_pos + 1);
    let lead = *bytes.get(first)?;
    // Optional unary sign in front of the operand.
    if lead == b'-' || lead == b'+' {
        first = skip_blanks(bytes, first + 1);
    }

    let head = *bytes.get(first)?;
    if head == b'(' {
        return past_matching_close(bytes, first);
    }
    if !is_atom(head) {
        return None;
    }

    let mut end = first;
    while end < bytes.len() && is_atom(bytes[end]) {
        end += 1;
    }
    if bytes.get(end) == Some(&b'(') {
        return past_matching_close(bytes, end);
    }
    Some(end)
}
209
210/// Rewrite `a && b` to `band(a, b)` and `a || b` to `bor(a, b)`, preserving
211/// the corpus shape `Float && Float` (which evalexpr 13 strict-types as
212/// `Boolean op Boolean` and rejects when the operands come back as Float).
213/// Two-pass with operator-precedence respect: (1) `&&` is rewritten first
214/// (higher precedence in C/EEL2/evalexpr — `a || b && c` parses as
215/// `a || (b && c)`), then (2) `||` is rewritten.
216pub(super) fn rewrite_logical_to_bandbor(s: &str) -> String {
217    let after_and = rewrite_logical_op(s, b'&', "band");
218    rewrite_logical_op(&after_and, b'|', "bor")
219}
220
221/// Inner rewriter for one operator family. `op_char` is `b'&'` (for `&&`)
222/// or `b'|'` (for `||`); `func_name` is the destination builtin
223/// (`"band"` / `"bor"`).
224fn rewrite_logical_op(s: &str, op_char: u8, func_name: &str) -> String {
225    let mut out = s.to_string();
226    let cap = out.len() + 16;
227    for _ in 0..cap {
228        let Some(op_pos) = find_double_op_pos(&out, op_char) else {
229            break;
230        };
231        let bytes = out.as_bytes();
232        let start = find_logical_left_operand_start(bytes, op_pos);
233        let end = find_logical_right_operand_end(bytes, op_pos + 2);
234        // Skip degenerate spans (operator at edge with no operand).
235        if start == op_pos || end == op_pos + 2 {
236            break;
237        }
238        let left = out[start..op_pos].trim();
239        let right = out[op_pos + 2..end].trim();
240        let head = &out[..start];
241        let tail = &out[end..];
242        out = format!("{head}{func_name}({left}, {right}){tail}");
243    }
244    out
245}
246
247/// Locate the next `<op_char><op_char>` (i.e. `&&` or `||`) in `s`.
248pub(super) fn find_double_op_pos(s: &str, op_char: u8) -> Option<usize> {
249    let bytes = s.as_bytes();
250    let mut i = 0;
251    while i + 1 < bytes.len() {
252        if bytes[i] == op_char
253            && bytes[i + 1] == op_char
254            && bytes.get(i + 2).copied() != Some(op_char)
255            && bytes.get(i + 2).copied() != Some(b'=')
256            && (i == 0 || bytes[i - 1] != op_char)
257        {
258            return Some(i);
259        }
260        i += 1;
261    }
262    None
263}
264
265/// Walk left from `op_pos` to find the start of the logical operator's
266/// left operand. Spans comparison/arithmetic ops at depth 0 (so
267/// `rg4 > 1.2 && change2` gives `rg4 > 1.2` as the left operand). Stops
268/// at the nearest depth-0 `(`, `,`, `;`, `&&`, `||`, or assignment `=`.
269pub(super) fn find_logical_left_operand_start(bytes: &[u8], op_pos: usize) -> usize {
270    let mut depth = 0i32;
271    let mut k = op_pos;
272    while k > 0 {
273        let p = bytes[k - 1];
274        match p {
275            b')' => depth += 1,
276            b'(' if depth == 0 => return k,
277            b'(' => depth -= 1,
278            b',' | b';' if depth == 0 => return k,
279            b'&' | b'|' if depth == 0 && k >= 2 && bytes[k - 2] == p => return k,
280            b'=' if depth == 0 => {
281                // Disambiguate `=` (assignment) from the second char of a
282                // 2-char operator (`==`, `<=`, `>=`, `!=`, `+=`, ...).
283                let prev = if k >= 2 { Some(bytes[k - 2]) } else { None };
284                if !matches!(
285                    prev,
286                    Some(b'=' | b'<' | b'>' | b'!' | b'+' | b'-' | b'*' | b'/' | b'%')
287                ) {
288                    return k;
289                }
290            }
291            _ => {}
292        }
293        k -= 1;
294    }
295    0
296}
297
298/// Walk right from `op_end` (the byte just past the 2-char operator) to
299/// find the end of the right operand. Mirror of
300/// [`find_logical_left_operand_start`].
301pub(super) fn find_logical_right_operand_end(bytes: &[u8], op_end: usize) -> usize {
302    let mut depth = 0i32;
303    let mut k = op_end;
304    while k < bytes.len() {
305        let p = bytes[k];
306        match p {
307            b'(' => depth += 1,
308            b')' if depth == 0 => return k,
309            b')' => depth -= 1,
310            b',' | b';' if depth == 0 => return k,
311            b'&' | b'|' if depth == 0 && k + 1 < bytes.len() && bytes[k + 1] == p => return k,
312            _ => {}
313        }
314        k += 1;
315    }
316    bytes.len()
317}
318
319/// Rewrite unary `!x` to `bnot(x)`. evalexpr 13 type-checks the `!` operator
320/// as `Boolean → Boolean`; MD2 EEL2 treats it as numeric NOT (returns 1.0 if
321/// `x == 0.0`, else 0.0) — same semantics as `bnot`.
322pub(super) fn rewrite_unary_bang_to_bnot(s: &str) -> String {
323    let mut out = s.to_string();
324    let cap = out.len() + 16;
325    let mut search_from = 0;
326    for _ in 0..cap {
327        let Some((bang_pos, end)) = find_unary_bang_operand(&out, search_from) else {
328            break;
329        };
330        let operand = out[bang_pos + 1..end].trim();
331        let head = &out[..bang_pos];
332        let tail = &out[end..];
333        let rewritten_head_len = head.len();
334        out = format!("{head}bnot({operand}){tail}");
335        search_from = rewritten_head_len;
336    }
337    out
338}
339
340/// Locate the next unary `!` plus its operand end in `s`, starting at
341/// byte offset `from`.
342fn find_unary_bang_operand(s: &str, from: usize) -> Option<(usize, usize)> {
343    let bytes = s.as_bytes();
344    let mut i = from;
345    while i < bytes.len() {
346        if bytes[i] == b'!'
347            // Skip `!=` — that's the inequality operator.
348            && bytes.get(i + 1).copied() != Some(b'=')
349            && is_unary_bang_context(bytes, i)
350            && let Some(end) = find_right_operand_end(bytes, i)
351        {
352            return Some((i, end));
353        }
354        i += 1;
355    }
356    None
357}
358
359/// True if the byte at `pos` is `!` in unary position.
360pub(super) fn is_unary_bang_context(bytes: &[u8], pos: usize) -> bool {
361    if pos == 0 {
362        return true;
363    }
364    let mut j = pos;
365    while j > 0 && matches!(bytes[j - 1], b' ' | b'\t' | b'\n' | b'\r') {
366        j -= 1;
367    }
368    if j == 0 {
369        return true;
370    }
371    matches!(
372        bytes[j - 1],
373        b'=' | b'+'
374            | b'-'
375            | b'*'
376            | b'/'
377            | b'%'
378            | b'<'
379            | b'>'
380            | b','
381            | b';'
382            | b'('
383            | b'!'
384            | b'&'
385            | b'|'
386            | b'?'
387            | b':'
388    )
389}
390
391/// Rewrite Boolean-producing parenthesised comparisons to a Float-returning
392/// `milkif(<cmp>, 1, 0)` call. evalexpr is strictly typed: a `(a > b)`
393/// produces `Boolean`, which then dies on `*`/`/`/`+` against a Float
394/// neighbour.
395///
396/// Fires on parens whose inner expression has **one comparison and no
397/// commas** — that's enough to disambiguate from real function calls like
398/// `milkif(a>b, c, d)`.
399pub(super) fn wrap_boolean_assignment_rhs(s: &str) -> String {
400    // Also exclude `;` from the inner character class so a `name(arg; arg)`
401    // call (normalised later by `rewrite_semis_in_call_args`) is not
402    // mis-wrapped as a bare bool group.
403    static BOOL_PAREN_REGEX: LazyLock<Regex> = LazyLock::new(|| {
404        Regex::new(r"\(\s*(?P<cmp>[^(),;]*?(?:>=|<=|==|!=|>|<)[^(),;]*?)\s*\)").unwrap()
405    });
406    BOOL_PAREN_REGEX
407        .replace_all(s, "milkif(${cmp}, 1, 0)")
408        .to_string()
409}
410
411/// Paren-balanced sibling of [`wrap_boolean_assignment_rhs`]. The regex
412/// pass only matches `(cmp)` whose interior contains no `(` or `,`; the
413/// walker variant tracks paren depth and catches `(lev1-gmegabuf(1)>0)`
414/// and `(y<=(0.4+0.1*cos(mang)))`. Recurses inner-first so a chain
415/// `(a > (b > c))` gets each level wrapped. Skips `(` that follow an
416/// identifier (function-call form).
417pub(super) fn wrap_paren_balanced_cmp(s: &str) -> String {
418    let bytes = s.as_bytes();
419    let mut out = String::with_capacity(s.len());
420    let mut i = 0usize;
421    while i < bytes.len() {
422        if bytes[i] == b'('
423            && let Some(close) = match_close_paren(bytes, i)
424        {
425            let inner = &s[i + 1..close];
426            let inner_processed = wrap_paren_balanced_cmp(inner);
427            let is_call = is_paren_func_call_open(bytes, i);
428            if !is_call
429                && contains_top_level_comparison(&inner_processed)
430                && !contains_top_level_comma(&inner_processed)
431                && !contains_top_level_byte(&inner_processed, b';')
432            {
433                out.push_str("milkif(");
434                out.push_str(&inner_processed);
435                out.push_str(", 1, 0)");
436            } else {
437                out.push('(');
438                out.push_str(&inner_processed);
439                out.push(')');
440            }
441            i = close + 1;
442            continue;
443        }
444        // UTF-8 safe advance.
445        let ch = s[i..].chars().next().unwrap();
446        let cl = ch.len_utf8();
447        out.push_str(&s[i..i + cl]);
448        i += cl;
449    }
450    out
451}
452
453/// Wrap `<ident> = <rhs>` when `<rhs>` contains a top-level comparison
454/// operator. evalexpr's `=` type-checks the RHS against the LHS's stored
455/// type; auto-init seeds every variable as `Float(0.0)`, so assigning a
456/// bare `Boolean` from `rand(100) >= 30` blows up at runtime.
457pub(super) fn wrap_bare_cmp_assignment(s: &str) -> String {
458    let parts = split_top_level_byte(s, b';');
459    if parts.len() == 1 {
460        return wrap_bare_cmp_assignment_one(s).unwrap_or_else(|| s.to_string());
461    }
462    let mut out = String::with_capacity(s.len() + 16);
463    let mut first = true;
464    for part in parts {
465        if !first {
466            out.push(';');
467        }
468        first = false;
469        match wrap_bare_cmp_assignment_one(part) {
470            Some(rewritten) => out.push_str(&rewritten),
471            None => out.push_str(part),
472        }
473    }
474    out
475}
476
477/// Single-statement helper for [`wrap_bare_cmp_assignment`]. Returns
478/// `Some(rewritten)` if a `<ident> = <rhs>` shape with top-level cmp
479/// in `<rhs>` was found, `None` otherwise (caller passes through).
480fn wrap_bare_cmp_assignment_one(stmt: &str) -> Option<String> {
481    let bytes = stmt.as_bytes();
482    let leading_ws_end = bytes
483        .iter()
484        .position(|b| !matches!(*b, b' ' | b'\t' | b'\n' | b'\r'))
485        .unwrap_or(bytes.len());
486    let leading_ws = &stmt[..leading_ws_end];
487    let rest = &stmt[leading_ws_end..];
488    let rest_bytes = rest.as_bytes();
489
490    if rest_bytes.is_empty() || !(rest_bytes[0].is_ascii_alphabetic() || rest_bytes[0] == b'_') {
491        return None;
492    }
493    let mut k = 1;
494    while k < rest_bytes.len() && is_ident_byte(rest_bytes[k]) {
495        k += 1;
496    }
497    let ident_end = k;
498    while k < rest_bytes.len() && matches!(rest_bytes[k], b' ' | b'\t') {
499        k += 1;
500    }
501    if k >= rest_bytes.len() || rest_bytes[k] != b'=' {
502        return None;
503    }
504    let next = rest_bytes.get(k + 1).copied().unwrap_or(0);
505    if next == b'=' {
506        return None;
507    }
508    let lhs = &rest[..ident_end];
509    let rhs = rest[k + 1..].trim_start();
510
511    if rhs.is_empty() {
512        return None;
513    }
514    if !contains_top_level_comparison(rhs) {
515        return None;
516    }
517    if contains_top_level_comma(rhs) {
518        return None;
519    }
520    Some(format!("{}{} = milkif({}, 1, 0)", leading_ws, lhs, rhs))
521}
522
523/// Walk `s` and, for every function-call paren group, rewrite top-level
524/// `;` in the args to `,` when the args contain no top-level `,` already.
525/// Calls named `loop` / `exec2` / `exec3` / `while` are skipped because
526/// their interceptors rely on `;`-chain syntax inside the arg body.
527pub(super) fn rewrite_semis_in_call_args(s: &str) -> String {
528    let bytes = s.as_bytes();
529    let mut out = Vec::with_capacity(bytes.len());
530    let mut i = 0usize;
531    while i < bytes.len() {
532        if bytes[i] == b'('
533            && let Some(close) = match_close_paren(bytes, i)
534        {
535            let inner = &s[i + 1..close];
536            let inner_rewritten = rewrite_semis_in_call_args(inner);
537            let is_call = is_paren_func_call_open(bytes, i);
538            let skip = is_call && call_name_skips_semi_rewrite(bytes, i);
539            let final_inner = if is_call && !skip {
540                let has_comma = contains_top_level_byte(&inner_rewritten, b',');
541                let has_semi = contains_top_level_byte(&inner_rewritten, b';');
542                if !has_comma && has_semi {
543                    replace_top_level_semis_with_commas(&inner_rewritten)
544                } else {
545                    inner_rewritten
546                }
547            } else {
548                inner_rewritten
549            };
550            out.push(b'(');
551            out.extend_from_slice(final_inner.as_bytes());
552            out.push(b')');
553            i = close + 1;
554            continue;
555        }
556        out.push(bytes[i]);
557        i += 1;
558    }
559    String::from_utf8(out).unwrap_or_else(|_| s.to_string())
560}
561
562/// Return `true` when the identifier preceding `paren_pos` names a
563/// builtin that uses `;`-chain semantics inside its args (`loop`,
564/// `exec2`, `exec3`, `while`).
565pub(super) fn call_name_skips_semi_rewrite(bytes: &[u8], paren_pos: usize) -> bool {
566    let Some(name) = call_name_before_paren(bytes, paren_pos) else {
567        return false;
568    };
569    name.eq_ignore_ascii_case(b"loop")
570        || name.eq_ignore_ascii_case(b"exec2")
571        || name.eq_ignore_ascii_case(b"exec3")
572        || name.eq_ignore_ascii_case(b"while")
573}
574
575/// Return the identifier bytes immediately preceding `paren_pos`
576/// (skipping inter-token whitespace). `None` if the byte before `(`
577/// after whitespace isn't an identifier — i.e. `(` is a grouping paren.
578pub(super) fn call_name_before_paren(bytes: &[u8], paren_pos: usize) -> Option<&[u8]> {
579    let mut j = paren_pos;
580    while j > 0 && matches!(bytes[j - 1], b' ' | b'\t') {
581        j -= 1;
582    }
583    let end = j;
584    let mut start = end;
585    while start > 0 && is_ident_byte(bytes[start - 1]) {
586        start -= 1;
587    }
588    if start == end {
589        return None;
590    }
591    Some(&bytes[start..end])
592}
593
/// Argument counts for the MD2 EEL2 builtins whose arity is fixed and
/// known in advance, as `(name, arity)` pairs. Names are unique within the
/// table. [`rewrite_arity_mismatched_semis`] consults it to decide when a
/// mixed `,` + `;` argument list can be normalised without ambiguity.
///
/// `loop` / `exec2` / `exec3` / `while` are deliberately not listed — their
/// bodies are `;`-chains, not fixed-arity argument lists.
const KNOWN_ARITY: &[(&[u8], usize)] = &[
    // Three-argument builtins.
    (b"clamp", 3),
    (b"if", 3),
    (b"milkif", 3),
    // Two-argument builtins.
    (b"above", 2),
    (b"atan2", 2),
    (b"band", 2),
    (b"below", 2),
    (b"bor", 2),
    (b"equal", 2),
    (b"fmod", 2),
    (b"gmegabuf_set", 2),
    (b"max", 2),
    (b"megabuf_set", 2),
    (b"min", 2),
    (b"pow", 2),
    (b"sigmoid", 2),
];
618
619/// Return the fixed arity of `name` if known.
620pub(super) fn known_arity_for(name: &[u8]) -> Option<usize> {
621    KNOWN_ARITY
622        .iter()
623        .find(|(n, _)| n.eq_ignore_ascii_case(name))
624        .map(|(_, a)| *a)
625}
626
/// Number of times `target` occurs in `s` at paren depth 0.
///
/// `(` and `)` only adjust the depth and are never counted themselves, even
/// when `target` is one of them. The depth is signed, so bytes following an
/// unmatched `)` are at a negative depth and are not counted.
fn count_top_level_byte(s: &str, target: u8) -> usize {
    let mut nesting = 0i32;
    s.bytes()
        .filter(|&byte| {
            if byte == b'(' {
                nesting += 1;
                false
            } else if byte == b')' {
                nesting -= 1;
                false
            } else {
                byte == target && nesting == 0
            }
        })
        .count()
}
641
642/// Arity-aware companion of [`rewrite_semis_in_call_args`]. Walks every
643/// function-call paren group and, for builtins with a known fixed
644/// arity, converts top-level `;` separators inside the arg list to `,`
645/// when doing so unambiguously yields the expected arg count.
646///
647/// Fires only on calls where `comma_count + semi_count + 1 == arity`.
648/// The simpler "no commas, all semis" case is already handled by
649/// [`rewrite_semis_in_call_args`]; this pass catches the mixed
650/// `milkif(cond, a; b)` / `clamp(x, lo; hi)` shapes the conservative
651/// pass leaves alone. Calls in the `loop`/`exec*`/`while` skip-set are
652/// also bypassed here.
653pub(super) fn rewrite_arity_mismatched_semis(s: &str) -> String {
654    let bytes = s.as_bytes();
655    let mut out = Vec::with_capacity(bytes.len());
656    let mut i = 0usize;
657    while i < bytes.len() {
658        if bytes[i] == b'('
659            && let Some(close) = match_close_paren(bytes, i)
660        {
661            let inner = &s[i + 1..close];
662            let inner_rewritten = rewrite_arity_mismatched_semis(inner);
663            let is_call = is_paren_func_call_open(bytes, i);
664            let skip = is_call && call_name_skips_semi_rewrite(bytes, i);
665            let final_inner = if is_call && !skip {
666                let name = call_name_before_paren(bytes, i);
667                let arity = name.and_then(known_arity_for);
668                match arity {
669                    Some(n) => {
670                        let commas = count_top_level_byte(&inner_rewritten, b',');
671                        let semis = count_top_level_byte(&inner_rewritten, b';');
672                        if commas > 0 && semis > 0 && commas + semis + 1 == n {
673                            replace_top_level_semis_with_commas(&inner_rewritten)
674                        } else {
675                            inner_rewritten
676                        }
677                    }
678                    None => inner_rewritten,
679                }
680            } else {
681                inner_rewritten
682            };
683            out.push(b'(');
684            out.extend_from_slice(final_inner.as_bytes());
685            out.push(b')');
686            i = close + 1;
687            continue;
688        }
689        out.push(bytes[i]);
690        i += 1;
691    }
692    String::from_utf8(out).unwrap_or_else(|_| s.to_string())
693}
694
695/// Single-pass `;`→`,` substitution at paren depth 0 of `s`.
696pub(super) fn replace_top_level_semis_with_commas(s: &str) -> String {
697    let mut out = Vec::with_capacity(s.len());
698    let mut depth = 0i32;
699    for b in s.bytes() {
700        match b {
701            b'(' => {
702                depth += 1;
703                out.push(b'(');
704            }
705            b')' => {
706                depth -= 1;
707                out.push(b')');
708            }
709            b';' if depth == 0 => out.push(b','),
710            other => out.push(other),
711        }
712    }
713    String::from_utf8(out).unwrap_or_else(|_| s.to_string())
714}
715
716/// Wrap each comma-separated arg of a function call whose body still
717/// contains a top-level `;` in `(...)`, so evalexpr parses the arg as a
718/// single value-producing statement chain.
719///
720/// Targets shapes like `milkif(cond, a=…; b=…; c=…, else)`, turning the
721/// THEN-branch `;`-chain into `(a=…; b=…; c=…)`. Trailing `;` is dropped
722/// inside the wrap so the chain produces a value rather than `Empty`.
723pub(super) fn wrap_chain_args_in_parens(s: &str) -> String {
724    let bytes = s.as_bytes();
725    let mut out = Vec::with_capacity(bytes.len());
726    let mut i = 0usize;
727    while i < bytes.len() {
728        if bytes[i] == b'('
729            && let Some(close) = match_close_paren(bytes, i)
730        {
731            let inner = &s[i + 1..close];
732            let inner_rewritten = wrap_chain_args_in_parens(inner);
733            let is_call = is_paren_func_call_open(bytes, i);
734            let skip = is_call && call_name_skips_semi_rewrite(bytes, i);
735            let final_inner = if is_call && !skip {
736                wrap_semi_chain_args(&inner_rewritten)
737            } else {
738                inner_rewritten
739            };
740            out.push(b'(');
741            out.extend_from_slice(final_inner.as_bytes());
742            out.push(b')');
743            i = close + 1;
744            continue;
745        }
746        out.push(bytes[i]);
747        i += 1;
748    }
749    String::from_utf8(out).unwrap_or_else(|_| s.to_string())
750}
751
752/// Split `args` on each top-level `,` and wrap any arg whose interior
753/// contains a top-level `;` in `(...)`. Helper for
754/// [`wrap_chain_args_in_parens`].
755fn wrap_semi_chain_args(args: &str) -> String {
756    let pieces = split_top_level_commas(args);
757    if pieces.len() < 2 {
758        return args.to_string();
759    }
760    let needs_wrap = pieces
761        .iter()
762        .any(|p| contains_top_level_byte(p, b';') && !arg_already_wrapped(p));
763    if !needs_wrap {
764        return args.to_string();
765    }
766    let mut out = String::with_capacity(args.len() + pieces.len() * 2);
767    for (idx, piece) in pieces.iter().enumerate() {
768        if idx > 0 {
769            out.push(',');
770        }
771        if contains_top_level_byte(piece, b';') && !arg_already_wrapped(piece) {
772            let trimmed = piece.trim();
773            let stripped = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
774            out.push('(');
775            out.push_str(stripped);
776            out.push(')');
777        } else {
778            out.push_str(piece);
779        }
780    }
781    out
782}
783
784/// Return `true` when `s` (already a comma-piece) consists of a single
785/// `(...)` group at depth 0. Detects the idempotent case where
786/// [`wrap_semi_chain_args`] has already run.
787pub(super) fn arg_already_wrapped(s: &str) -> bool {
788    let t = s.trim();
789    let bytes = t.as_bytes();
790    if bytes.len() < 2 || bytes[0] != b'(' || bytes[bytes.len() - 1] != b')' {
791        return false;
792    }
793    match_close_paren(bytes, 0) == Some(bytes.len() - 1)
794}
795
796/// Rewrite Python-style chained comparisons (`A > B <= C <= 1`) into the
797/// explicit AND-chain (`(A > B) && (B <= C) && (C <= 1)`) that evalexpr's
798/// left-associative comparison precedence can type-check.
799///
800/// MD2's EEL2 historically reads chained comparisons as pairwise AND, like
801/// Python: `a < b < c` means "`a < b` AND `b < c`". evalexpr 13 inherits C
802/// semantics, so `A > B <= C` evaluates to `(A > B) <= C`, which is
803/// `Boolean <= C` and dies.
804pub(super) fn rewrite_chained_comparisons(s: &str) -> String {
805    let bytes = s.as_bytes();
806    let mut rewritten = Vec::with_capacity(bytes.len());
807    let mut i = 0usize;
808    while i < bytes.len() {
809        if bytes[i] == b'('
810            && let Some(close) = match_close_paren(bytes, i)
811        {
812            let inner = &s[i + 1..close];
813            let inner_rewritten = rewrite_chained_comparisons(inner);
814            rewritten.push(b'(');
815            rewritten.extend_from_slice(inner_rewritten.as_bytes());
816            rewritten.push(b')');
817            i = close + 1;
818            continue;
819        }
820        rewritten.push(bytes[i]);
821        i += 1;
822    }
823    let walked = String::from_utf8(rewritten).unwrap_or_else(|_| s.to_string());
824    // Apply the chain rewrite at this level, splitting on top-level commas
825    // so call args are processed independently.
826    let pieces = split_top_level_commas(&walked);
827    if pieces.len() == 1 {
828        return rewrite_chained_at_this_level(&walked);
829    }
830    let mut out = String::with_capacity(walked.len() + 16);
831    for (idx, piece) in pieces.iter().enumerate() {
832        if idx > 0 {
833            out.push(',');
834        }
835        out.push_str(&rewrite_chained_at_this_level(piece));
836    }
837    out
838}
839
840/// Apply the chained-comparison rewrite to a single comma-piece.
841fn rewrite_chained_at_this_level(s: &str) -> String {
842    let ops = top_level_comparison_ops(s);
843    if ops.len() < 2 {
844        return s.to_string();
845    }
846    let bytes = s.as_bytes();
847    let mut operands: Vec<&str> = Vec::with_capacity(ops.len() + 1);
848    let mut cursor = 0usize;
849    for (start, end) in &ops {
850        operands.push(s[cursor..*start].trim());
851        cursor = *end;
852    }
853    operands.push(s[cursor..].trim());
854    let op_strs: Vec<&str> = ops
855        .iter()
856        .map(|(start, end)| std::str::from_utf8(&bytes[*start..*end]).unwrap_or("=="))
857        .collect();
858    let mut out = String::with_capacity(s.len() + ops.len() * 8);
859    for k in 0..ops.len() {
860        if k > 0 {
861            out.push_str(" && ");
862        }
863        out.push('(');
864        out.push_str(operands[k]);
865        out.push_str(op_strs[k]);
866        out.push_str(operands[k + 1]);
867        out.push(')');
868    }
869    out
870}
871
872/// Return a list of `(start, end)` byte offsets for every top-level
873/// comparison operator in `s` (`<`, `>`, `<=`, `>=`, `==`, `!=`).
874pub(super) fn top_level_comparison_ops(s: &str) -> Vec<(usize, usize)> {
875    let bytes = s.as_bytes();
876    let mut ops = Vec::new();
877    let mut depth = 0i32;
878    let mut i = 0usize;
879    while i < bytes.len() {
880        match bytes[i] {
881            b'(' => {
882                depth += 1;
883                i += 1;
884            }
885            b')' => {
886                depth -= 1;
887                i += 1;
888            }
889            b'<' | b'>' if depth == 0 => {
890                let end = if bytes.get(i + 1) == Some(&b'=') {
891                    i + 2
892                } else {
893                    i + 1
894                };
895                ops.push((i, end));
896                i = end;
897            }
898            b'=' if depth == 0 && bytes.get(i + 1) == Some(&b'=') => {
899                ops.push((i, i + 2));
900                i += 2;
901            }
902            b'!' if depth == 0 && bytes.get(i + 1) == Some(&b'=') => {
903                ops.push((i, i + 2));
904                i += 2;
905            }
906            _ => {
907                i += 1;
908            }
909        }
910    }
911    ops
912}