onedrop_eval/evaluator/
mod.rs

1//! Evaluator for Milkdrop expressions.
2//!
3//! [`MilkEvaluator`] is the entry point. Most callers go through
4//! [`MilkEvaluator::eval`] (single expression), [`MilkEvaluator::compile`] +
5//! [`MilkEvaluator::eval_compiled`] (precompiled hot path), or
6//! [`MilkEvaluator::eval_equation_list`] (paren-balance-aware join over a
7//! preset's flat equation list).
8//!
9//! The crate's internals are split across:
10//!
11//! - [`preprocess`] — the regex-driven preprocessing pipeline that turns MD2
12//!   EEL2 source into something `evalexpr` will parse.
13//! - [`rewriters`] — AST-style rewriters consumed by `preprocess_expression`.
14//! - [`gmegabuf`] — `gmegabuf` / `megabuf` write rewrite.
15//! - [`joiner`] — multi-equation join logic and top-level `loop`/`exec`/
16//!   `while` interceptors.
17
18mod gmegabuf;
19mod joiner;
20mod preprocess;
21mod rewriters;
22
23#[cfg(test)]
24mod tests;
25
26use crate::context::MilkContext;
27use crate::error::{EvalError, Result};
28use evalexpr::Node;
29
30pub use joiner::EvaluationStats;
31
32use preprocess::MAX_EXPRESSION_LENGTH;
33
/// Per-point state for a custom-wave (or custom-shape) per-point loop.
///
/// One value of this type serves both directions of a single iteration:
/// the caller fills it in to seed the loop variables, and the evaluator
/// hands one back whose `x`, `y`, `r`, `g`, `b`, `a` were read from the
/// context once the equations have run. `sample`, `value1` and `value2`
/// are inputs only and are echoed back unchanged.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WavePoint {
    /// Normalised sample index in `[0.0, 1.0]` (or `0.0` when `samples == 1`).
    pub sample: f64,
    /// Left-channel value (raw audio sample, or FFT bin if `b_spectrum`).
    pub value1: f64,
    /// Right-channel value (mirrored from value1 with N/2 offset for mono).
    pub value2: f64,
    /// Horizontal position; seeded in, then read back after evaluation.
    pub x: f64,
    /// Vertical position; seeded in, then read back after evaluation.
    pub y: f64,
    /// Red colour channel; seeded in, then read back after evaluation.
    pub r: f64,
    /// Green colour channel; seeded in, then read back after evaluation.
    pub g: f64,
    /// Blue colour channel; seeded in, then read back after evaluation.
    pub b: f64,
    /// Alpha channel; seeded in, then read back after evaluation.
    pub a: f64,
}

impl Default for WavePoint {
    /// Zeroed inputs, position `(0.5, 0.5)` and every colour channel at `1.0`.
    fn default() -> Self {
        const NO_INPUT: f64 = 0.0;
        const MIDPOINT: f64 = 0.5;
        const FULL: f64 = 1.0;

        Self {
            sample: NO_INPUT,
            value1: NO_INPUT,
            value2: NO_INPUT,
            x: MIDPOINT,
            y: MIDPOINT,
            r: FULL,
            g: FULL,
            b: FULL,
            a: FULL,
        }
    }
}
69
/// Per-instance state for a custom-shape (`shapecode_N`) per-instance loop.
///
/// MD2 shape per-frame equations may mutate any of these fields. Each
/// instance starts from seed values taken fresh from the preset's scalar
/// block — instances do NOT carry state across each other (unlike
/// `WavePoint`). State that has to persist across instances lives in the
/// `q*` / `t*` channels.
///
/// `instance` is the 0-based loop counter; the evaluator seeds it (together
/// with `sides` and `num_inst`) into the context before each call so that
/// equations such as `ang = ang + 0.1 * instance` work.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ShapeInstance {
    /// 0-based loop counter of this instance.
    pub instance: f64,
    /// Value seeded into the context as `num_inst`.
    pub num_inst: f64,
    /// Value seeded into the context as `sides`.
    pub sides: f64,
    pub x: f64,
    pub y: f64,
    pub rad: f64,
    pub ang: f64,
    pub tex_zoom: f64,
    pub tex_ang: f64,
    pub r: f64,
    pub g: f64,
    pub b: f64,
    pub a: f64,
    pub r2: f64,
    pub g2: f64,
    pub b2: f64,
    pub a2: f64,
    pub border_r: f64,
    pub border_g: f64,
    pub border_b: f64,
    pub border_a: f64,
    /// MD2 `thick` flag (drawn as bool, kept f64 to round-trip via the
    /// eval context with no extra conversions). 0 = thin outline, 1 = thick.
    pub thick: f64,
    /// MD2 `additive` flag (0 = alpha blend, 1 = additive).
    pub additive: f64,
}

impl Default for ShapeInstance {
    /// Seed values used when nothing else is supplied: a single 4-sided
    /// instance at `(0.5, 0.5)` with radius `0.1`, primary colour and border
    /// RGB at `1.0`, secondary colour and border alpha at `0.0`, and both
    /// flags cleared.
    fn default() -> Self {
        const ZERO: f64 = 0.0;
        const ONE: f64 = 1.0;
        const MIDPOINT: f64 = 0.5;

        Self {
            // Loop bookkeeping.
            instance: ZERO,
            num_inst: ONE,
            sides: 4.0,
            // Geometry and texture mapping.
            x: MIDPOINT,
            y: MIDPOINT,
            rad: 0.1,
            ang: ZERO,
            tex_zoom: ONE,
            tex_ang: ZERO,
            // Primary colour.
            r: ONE,
            g: ONE,
            b: ONE,
            a: ONE,
            // Secondary colour.
            r2: ZERO,
            g2: ZERO,
            b2: ZERO,
            a2: ZERO,
            // Border colour (alpha zero by default).
            border_r: ONE,
            border_g: ONE,
            border_b: ONE,
            border_a: ZERO,
            // Flags.
            thick: ZERO,
            additive: ZERO,
        }
    }
}
139
/// Evaluator for Milkdrop expressions.
///
/// Owns the [`MilkContext`] that every expression reads from and writes to,
/// plus a list of `(source, compiled node)` pairs kept as a compile cache.
pub struct MilkEvaluator {
    /// Execution context
    pub(super) context: MilkContext,

    /// Compiled expressions cache
    // Nothing in this file reads the cache — it is only created and cleared —
    // hence the lint allowance. NOTE(review): confirm whether a submodule
    // still consults it; if not, the field can probably be removed.
    #[allow(dead_code)]
    compiled_cache: Vec<(String, Node)>,
}
149
150/// Custom `Clone` impl that skips `compiled_cache`. The cache is per-call
151/// memoisation for the regex-preprocess path; worker threads (the wave
152/// and warp parallel passes) never re-enter that path — they're given
153/// pre-compiled `Node`s by the caller — so cloning the cache is pure
154/// overhead. The cache can grow into the hundreds of entries on a
155/// long-running engine, so this matters: cloning a 100-entry
156/// `Vec<(String, Node)>` is ~30 µs that goes away for free.
157impl Clone for MilkEvaluator {
158    fn clone(&self) -> Self {
159        Self {
160            context: self.context.clone(),
161            compiled_cache: Vec::new(),
162        }
163    }
164}
165
166impl MilkEvaluator {
167    /// Create a new evaluator.
168    pub fn new() -> Self {
169        Self {
170            context: MilkContext::new(),
171            compiled_cache: Vec::new(),
172        }
173    }
174
175    /// Get a reference to the context.
176    pub fn context(&self) -> &MilkContext {
177        &self.context
178    }
179
180    /// Get a mutable reference to the context.
181    pub fn context_mut(&mut self) -> &mut MilkContext {
182        &mut self.context
183    }
184
185    /// Pre-compile an expression into an evalexpr `Node` that can be evaluated
186    /// repeatedly without re-parsing.
187    ///
188    /// Pre-processing (auto-init of undefined variables, integer→float
189    /// promotion in assignments, `if(` → `milkif(` rewrite) runs as part of
190    /// compilation, so any variable referenced in `expression` is also
191    /// registered in `self.context` as a side effect — exactly like a normal
192    /// `eval` call. This means a context cloned out of this evaluator after a
193    /// `compile` call already contains every var the compiled `Node` will
194    /// look up at evaluation time.
195    pub fn compile(&mut self, expression: &str) -> Result<Node> {
196        if expression.len() > MAX_EXPRESSION_LENGTH {
197            return Err(EvalError::SyntaxError {
198                expression: expression.chars().take(100).collect(),
199                reason: format!(
200                    "Expression too long: {} bytes (max {})",
201                    expression.len(),
202                    MAX_EXPRESSION_LENGTH
203                ),
204            });
205        }
206        let expr = expression.trim().trim_end_matches(';').trim();
207        let processed = self.preprocess_expression(expr);
208        evalexpr::build_operator_tree(&processed).map_err(|e| EvalError::SyntaxError {
209            expression: expr.to_string(),
210            reason: e.to_string(),
211        })
212    }
213
214    /// Evaluate a single expression.
215    pub fn eval(&mut self, expression: &str) -> Result<f64> {
216        if expression.len() > MAX_EXPRESSION_LENGTH {
217            return Err(EvalError::SyntaxError {
218                expression: expression.chars().take(100).collect(),
219                reason: format!(
220                    "Expression too long: {} bytes (max {})",
221                    expression.len(),
222                    MAX_EXPRESSION_LENGTH
223                ),
224            });
225        }
226
227        let expr = expression.trim().trim_end_matches(';').trim();
228
229        if expr.is_empty() {
230            return Ok(0.0);
231        }
232
233        let processed_expr = self.preprocess_expression(expr);
234
235        // Intercept top-level `loop(N, body)` (and `exec2`/`exec3`/`while`)
236        // calls *before* evalexpr sees them. evalexpr can't represent these
237        // EEL2 idioms natively (their bodies are `;`-chains inside what looks
238        // like a function-call arg list).
239        self.eval_processed_with_loops(&processed_expr, &processed_expr)
240    }
241
242    /// Evaluate multiple expressions (per-frame equations).
243    pub fn eval_per_frame(&mut self, equations: &[String]) -> Result<()> {
244        for equation in equations {
245            self.eval(equation)?;
246        }
247        Ok(())
248    }
249
250    /// Evaluate per-pixel equations for a single pixel.
251    pub fn eval_per_pixel(
252        &mut self,
253        x: f64,
254        y: f64,
255        rad: f64,
256        ang: f64,
257        equations: &[String],
258    ) -> Result<()> {
259        self.context.set_pixel(x, y, rad, ang);
260        for equation in equations {
261            self.eval(equation)?;
262        }
263        Ok(())
264    }
265
266    /// Compile a batch of source strings into reusable `Node`s. Use this on
267    /// load_preset for custom wave/shape equations, which run hundreds of
268    /// times per frame — avoiding per-eval reparse is a 10×+ speedup.
269    pub fn compile_batch(&mut self, equations: &[String]) -> Result<Vec<Node>> {
270        equations.iter().map(|eq| self.compile(eq)).collect()
271    }
272
273    /// Evaluate a pre-compiled `Node` against this evaluator's context. Skips
274    /// the regex preprocess that `eval()` runs every call — variables referenced
275    /// in the source string were already auto-initialized at `compile()` time.
276    pub fn eval_compiled(&mut self, node: &Node) -> Result<f64> {
277        match node.eval_with_context_mut(&mut self.context) {
278            Ok(value) => match value {
279                evalexpr::Value::Float(f) => Ok(f),
280                evalexpr::Value::Int(i) => Ok(i as f64),
281                evalexpr::Value::Boolean(b) => Ok(if b { 1.0 } else { 0.0 }),
282                evalexpr::Value::Empty => Ok(0.0),
283                _ => Err(EvalError::TypeError {
284                    expected: "number".to_string(),
285                    got: format!("{:?}", value),
286                }),
287            },
288            Err(e) => Err(EvalError::SyntaxError {
289                expression: "<compiled>".to_string(),
290                reason: e.to_string(),
291            }),
292        }
293    }
294
295    /// Evaluate per-frame equations for one instance of a custom shape
296    /// (`shapecode_N`). Seeds the loop variables (`instance`, `sides`,
297    /// `num_inst`) plus the full geometry/colour state, runs the
298    /// compiled per-frame Nodes, reads everything back into a fresh
299    /// `ShapeInstance`.
300    ///
301    /// Unlike `eval_per_point` (which threads state across iterations
302    /// within a frame), each shape instance is independent: callers
303    /// pass the shape's scalar block seed every time and `instance`
304    /// changes per call. Persistent state across instances lives in
305    /// `q*` / `t*` channels, which the context carries.
306    pub fn eval_per_shape_instance(
307        &mut self,
308        instance: ShapeInstance,
309        compiled: &[Node],
310    ) -> Result<ShapeInstance> {
311        Self::seed_shape_instance(&mut self.context, instance);
312        for node in compiled {
313            self.eval_compiled(node)?;
314        }
315        Ok(Self::read_shape_instance(&self.context, instance))
316    }
317
318    /// Bytecode-aware companion of [`eval_per_shape_instance`]. Runs the
319    /// bytecode VM when [`CompiledBlock::bytecode`] is `Some`, otherwise
320    /// walks the evalexpr `Node` list.
321    pub fn run_per_shape_instance(
322        &mut self,
323        instance: ShapeInstance,
324        block: &crate::compiled_block::CompiledBlock,
325    ) -> Result<ShapeInstance> {
326        Self::seed_shape_instance(&mut self.context, instance);
327        if let Some(bc) = block.bytecode() {
328            bc.run(&mut self.context);
329        } else {
330            for node in block.nodes() {
331                self.eval_compiled(node)?;
332            }
333        }
334        Ok(Self::read_shape_instance(&self.context, instance))
335    }
336
337    /// Push the 23 hot-var seed values from `instance` into `ctx`.
338    #[inline]
339    fn seed_shape_instance(ctx: &mut MilkContext, instance: ShapeInstance) {
340        ctx.set("instance", instance.instance);
341        ctx.set("num_inst", instance.num_inst);
342        ctx.set("sides", instance.sides);
343        ctx.set("x", instance.x);
344        ctx.set("y", instance.y);
345        ctx.set("rad", instance.rad);
346        ctx.set("ang", instance.ang);
347        ctx.set("tex_zoom", instance.tex_zoom);
348        ctx.set("tex_ang", instance.tex_ang);
349        ctx.set("r", instance.r);
350        ctx.set("g", instance.g);
351        ctx.set("b", instance.b);
352        ctx.set("a", instance.a);
353        ctx.set("r2", instance.r2);
354        ctx.set("g2", instance.g2);
355        ctx.set("b2", instance.b2);
356        ctx.set("a2", instance.a2);
357        ctx.set("border_r", instance.border_r);
358        ctx.set("border_g", instance.border_g);
359        ctx.set("border_b", instance.border_b);
360        ctx.set("border_a", instance.border_a);
361        ctx.set("thick", instance.thick);
362        ctx.set("additive", instance.additive);
363    }
364
365    /// Companion to [`seed_shape_instance`] — read every per-instance var
366    /// back, falling back to `instance.*` when a slot is missing.
367    #[inline]
368    fn read_shape_instance(c: &MilkContext, instance: ShapeInstance) -> ShapeInstance {
369        ShapeInstance {
370            instance: instance.instance,
371            num_inst: instance.num_inst,
372            sides: c.get("sides").unwrap_or(instance.sides),
373            x: c.get("x").unwrap_or(instance.x),
374            y: c.get("y").unwrap_or(instance.y),
375            rad: c.get("rad").unwrap_or(instance.rad),
376            ang: c.get("ang").unwrap_or(instance.ang),
377            tex_zoom: c.get("tex_zoom").unwrap_or(instance.tex_zoom),
378            tex_ang: c.get("tex_ang").unwrap_or(instance.tex_ang),
379            r: c.get("r").unwrap_or(instance.r),
380            g: c.get("g").unwrap_or(instance.g),
381            b: c.get("b").unwrap_or(instance.b),
382            a: c.get("a").unwrap_or(instance.a),
383            r2: c.get("r2").unwrap_or(instance.r2),
384            g2: c.get("g2").unwrap_or(instance.g2),
385            b2: c.get("b2").unwrap_or(instance.b2),
386            a2: c.get("a2").unwrap_or(instance.a2),
387            border_r: c.get("border_r").unwrap_or(instance.border_r),
388            border_g: c.get("border_g").unwrap_or(instance.border_g),
389            border_b: c.get("border_b").unwrap_or(instance.border_b),
390            border_a: c.get("border_a").unwrap_or(instance.border_a),
391            thick: c.get("thick").unwrap_or(instance.thick),
392            additive: c.get("additive").unwrap_or(instance.additive),
393        }
394    }
395
396    /// Bytecode counterpart of [`eval_per_point`]. Skips both the
397    /// `String → Value::Float` wrapping and the recursive operator-tree
398    /// walk that evalexpr does for each child node — the bytecode VM
399    /// just reads/writes hot slots by index and runs a flat opcode stream.
400    pub fn run_per_point_bc(
401        &mut self,
402        point: WavePoint,
403        compiled: &crate::bytecode::CompiledBytecode,
404    ) -> WavePoint {
405        let ctx = &mut self.context;
406        ctx.set("sample", point.sample);
407        ctx.set("value1", point.value1);
408        ctx.set("value2", point.value2);
409        ctx.set("x", point.x);
410        ctx.set("y", point.y);
411        ctx.set("r", point.r);
412        ctx.set("g", point.g);
413        ctx.set("b", point.b);
414        ctx.set("a", point.a);
415        compiled.run(ctx);
416        let c = &self.context;
417        WavePoint {
418            sample: point.sample,
419            value1: point.value1,
420            value2: point.value2,
421            x: c.get("x").unwrap_or(point.x),
422            y: c.get("y").unwrap_or(point.y),
423            r: c.get("r").unwrap_or(point.r),
424            g: c.get("g").unwrap_or(point.g),
425            b: c.get("b").unwrap_or(point.b),
426            a: c.get("a").unwrap_or(point.a),
427        }
428    }
429
430    /// Evaluate per-point equations for one sample inside a custom wave or
431    /// shape loop. `point` carries the input vars (`sample`, `value1`,
432    /// `value2`) and the carry-over geometry/colour (`x`, `y`, `r`, `g`,
433    /// `b`, `a`). The output echoes the same fields back, read from the
434    /// context after the equations ran.
435    ///
436    /// The caller is responsible for seeding the loop on point 0 with the
437    /// wave's base colour and any saved `x`/`y`, then threading the previous
438    /// point's output into the next call — this matches MD2's "trail across
439    /// samples within a frame" semantics.
440    pub fn eval_per_point(&mut self, point: WavePoint, compiled: &[Node]) -> Result<WavePoint> {
441        let ctx = &mut self.context;
442        ctx.set("sample", point.sample);
443        ctx.set("value1", point.value1);
444        ctx.set("value2", point.value2);
445        ctx.set("x", point.x);
446        ctx.set("y", point.y);
447        ctx.set("r", point.r);
448        ctx.set("g", point.g);
449        ctx.set("b", point.b);
450        ctx.set("a", point.a);
451        for node in compiled {
452            self.eval_compiled(node)?;
453        }
454        let c = &self.context;
455        Ok(WavePoint {
456            sample: point.sample,
457            value1: point.value1,
458            value2: point.value2,
459            x: c.get("x").unwrap_or(point.x),
460            y: c.get("y").unwrap_or(point.y),
461            r: c.get("r").unwrap_or(point.r),
462            g: c.get("g").unwrap_or(point.g),
463            b: c.get("b").unwrap_or(point.b),
464            a: c.get("a").unwrap_or(point.a),
465        })
466    }
467
468    /// Parse an assignment expression and update context.
469    /// Returns the assigned value.
470    pub fn eval_assignment(&mut self, expression: &str) -> Result<f64> {
471        let result = self.eval(expression)?;
472
473        if let Some((var_name, _)) = expression.split_once('=') {
474            let var_name = var_name.trim();
475            self.context.set_var(var_name, result);
476        }
477
478        Ok(result)
479    }
480
481    /// Reset the evaluator to initial state.
482    pub fn reset(&mut self) {
483        self.context = MilkContext::new();
484        self.compiled_cache.clear();
485    }
486}
487
488impl Default for MilkEvaluator {
489    fn default() -> Self {
490        Self::new()
491    }
492}
493
494// ---------- Shared byte-walking utilities ----------
495//
496// These helpers walk paren-balanced byte slices and are used by every
497// submodule (preprocess, rewriters, gmegabuf, joiner). They live here so
498// each submodule pulls them from `super::*` without going through a
499// separate `utils` module.
500
501/// `true` when `b` is an ASCII identifier byte (`[A-Za-z0-9_]`).
502pub(super) fn is_ident_byte(b: u8) -> bool {
503    b.is_ascii_alphanumeric() || b == b'_'
504}
505
506/// Find the matching `)` for the `(` at `open`. Returns the byte offset
507/// of the closer or `None` if unbalanced.
508pub(super) fn match_close_paren(bytes: &[u8], open: usize) -> Option<usize> {
509    debug_assert_eq!(bytes[open], b'(');
510    let mut depth = 1usize;
511    let mut k = open + 1;
512    while k < bytes.len() {
513        match bytes[k] {
514            b'(' => depth += 1,
515            b')' => {
516                depth -= 1;
517                if depth == 0 {
518                    return Some(k);
519                }
520            }
521            _ => {}
522        }
523        k += 1;
524    }
525    None
526}
527
528/// `true` if `s` contains a comparison operator (`> < >= <= == !=`) at
529/// paren depth 0.
530pub(super) fn contains_top_level_comparison(s: &str) -> bool {
531    let bytes = s.as_bytes();
532    let mut depth = 0i32;
533    let mut i = 0;
534    while i < bytes.len() {
535        match bytes[i] {
536            b'(' => depth += 1,
537            b')' => depth -= 1,
538            b'>' | b'<' if depth == 0 => return true,
539            b'=' | b'!' if depth == 0 && bytes.get(i + 1).copied() == Some(b'=') => return true,
540            _ => {}
541        }
542        i += 1;
543    }
544    false
545}
546
547/// `true` if `s` contains a `,` at paren depth 0.
548pub(super) fn contains_top_level_comma(s: &str) -> bool {
549    contains_top_level_byte(s, b',')
550}
551
552/// Generic "byte at depth 0" check.
553pub(super) fn contains_top_level_byte(s: &str, target: u8) -> bool {
554    let mut depth = 0i32;
555    for b in s.bytes() {
556        match b {
557            b'(' => depth += 1,
558            b')' => depth -= 1,
559            x if x == target && depth == 0 => return true,
560            _ => {}
561        }
562    }
563    false
564}
565
566/// Generic "split on byte at depth 0".
567pub(super) fn split_top_level_byte(s: &str, target: u8) -> Vec<&str> {
568    let bytes = s.as_bytes();
569    let mut parts = Vec::new();
570    let mut depth = 0i32;
571    let mut start = 0usize;
572    for (i, b) in bytes.iter().enumerate() {
573        match *b {
574            b'(' => depth += 1,
575            b')' => depth -= 1,
576            x if x == target && depth == 0 => {
577                parts.push(&s[start..i]);
578                start = i + 1;
579            }
580            _ => {}
581        }
582    }
583    parts.push(&s[start..]);
584    parts
585}
586
587/// Split `inner` on each top-level `,` (commas at paren depth 0). Always
588/// returns at least one element (the whole `inner` if no top-level comma).
589pub(super) fn split_top_level_commas(inner: &str) -> Vec<&str> {
590    let bytes = inner.as_bytes();
591    let mut parts = Vec::new();
592    let mut depth = 0usize;
593    let mut start = 0usize;
594    for (i, b) in bytes.iter().enumerate() {
595        match *b {
596            b'(' => depth += 1,
597            b')' => depth = depth.saturating_sub(1),
598            b',' if depth == 0 => {
599                parts.push(&inner[start..i]);
600                start = i + 1;
601            }
602            _ => {}
603        }
604    }
605    parts.push(&inner[start..]);
606    parts
607}
608
609/// `true` if the `(` at byte offset `paren_pos` opens a function-call arg
610/// list (preceded by an identifier or digit, modulo whitespace) rather
611/// than a standalone grouping expression.
612pub(super) fn is_paren_func_call_open(bytes: &[u8], paren_pos: usize) -> bool {
613    let mut j = paren_pos;
614    while j > 0 && matches!(bytes[j - 1], b' ' | b'\t') {
615        j -= 1;
616    }
617    j > 0 && is_ident_byte(bytes[j - 1])
618}
619
/// Whether `name` is a registered builtin identifier that the auto-init pass
/// should leave alone.
///
/// The match is exact and case-sensitive. The list lives in one place so the
/// preprocess and case-normalising helpers stay in sync.
pub(crate) fn is_builtin_ident(name: &str) -> bool {
    const BUILTINS: &[&str] = &[
        "sin", "cos", "tan", "sqrt", "abs", "pow", "exp", "log", "log10", "ln", "if", "milkif",
        "min", "max", "floor", "ceil", "round", "rand", "above", "below", "equal", "bnot", "band",
        "bor", "int", "fmod", "clamp", "sinh", "cosh", "tanh", "asin", "acos", "atan", "atan2",
        "sqr", "rad", "deg", "fract", "trunc", "sign", "sigmoid", "loop", "while", "exec2",
        "exec3", "gmegabuf", "megabuf",
    ];
    BUILTINS.contains(&name)
}