onedrop_eval/
bytecode.rs

1//! Stack-machine bytecode for MD2 per-point / per-pixel hot loops.
2//!
3//! ## Why
4//!
5//! evalexpr evaluates a compiled `Node` by recursively walking the
6//! operator tree. Each visit allocates a `Vec<Value>` to gather child
7//! results, clones `Value`s through the call stack, and looks up
8//! identifiers by string. On a per-point block iterated 512 times per
9//! wave times 4 waves per frame, that overhead dominates the actual
10//! arithmetic.
11//!
12//! This module translates the AST once at preset load into a flat
13//! `Vec<Op>` and runs it through a tight interpreter loop:
14//!
15//! * **Hot vars** (`x`, `y`, `r`, `g`, `b`, `a`, …) become single
16//!   `LoadHot(i)` / `StoreHot(i)` opcodes that hit the array slot
17//!   directly — no name match, no Value cloning.
18//! * **q-channels** (`q1`..`q64`) use `LoadQ(i)` / `StoreQ(i)` for the
19//!   same reason.
20//! * **Math functions** (`sin`, `sqrt`, `pow`, `if`, …) lower to
21//!   dedicated opcodes that pop arguments off the f64 stack and push
22//!   the result — no `Function::call` dispatch, no `Value::Tuple`
23//!   allocation per call.
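//! * **Cold reads/writes** (custom vars, per-frame builtins like
//!   `bass`) are interned into the `MilkContext` cold slab at compile
//!   time and accessed through `LoadCold(idx)` / `StoreCold(idx)`, which
//!   still beats evalexpr's clone-and-dispatch by skipping the tree walk
//!   and the per-access name lookup.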
27//!
28//! ## Scope
29//!
30//! The compiler bails out (`Err(CompileError::Unsupported)`) on any
31//! operator or function it doesn't recognise. Callers must keep the
32//! original evalexpr `Vec<Node>` around and use it as the fallback.
33//! What's supported today covers the bulk of corpus per_point /
34//! per_pixel blocks: arithmetic, comparisons, conditionals, the
35//! standard math library, and reads/writes of hot + q channels.
36//!
//! Stateful builtins — `rand`, `gmegabuf` / `megabuf`, and
//! `gmegabuf_set` / `megabuf_set` — are lowered to dedicated opcodes
//! (see the "Stateful builtins" group in [`Op`]).
//!
//! Explicitly NOT supported (caller must fall back to evalexpr):
//!
//! * `loop()` / `while()` / `exec2()` / `exec3()` — the evaluator's
//!   `eval_processed_with_loops` rewrites these structurally; they
//!   never reach this compiler.
//! * String values and tuples (other than as function-argument
//!   wrappers).
47
48use crate::context::{MilkContext, hot_index_of, q_index_of};
49use evalexpr::{Node, Operator, Value};
50
/// Capacity, in `f64` slots, of the fixed-size operand stack that
/// [`CompiledBytecode::run`] allocates on every invocation. The
/// interpreter's `debug_assert!`s check pushes against this bound.
const STACK_SIZE: usize = 256;
52
/// Stack-machine opcodes. The `u8` / `u32` payloads stay inline so the
/// `Vec<Op>` is a flat array the interpreter scans linearly.
///
/// There are no jump opcodes: the interpreter executes every opcode
/// exactly once, front to back.
#[derive(Debug, Clone, Copy)]
pub enum Op {
    /// Push `consts[idx]` onto the stack.
    PushConst(u32),
    /// Drop the top of stack.
    Pop,

    /// Push the value of hot slot `idx` (one of the `HOT_*` constants).
    LoadHot(u8),
    /// Pop and write into hot slot `idx`.
    StoreHot(u8),

    /// Push q-channel `idx` (0-based: `LoadQ(0)` → `q1`).
    LoadQ(u8),
    /// Pop and write into q-channel `idx`.
    StoreQ(u8),

    /// Push the cold-slab value at slot `idx` (`ctx.cold_get_idx(idx)`).
    /// The slot index is resolved at compile time by interning the
    /// variable name into the context.
    LoadCold(u32),
    /// Pop and write into cold-slab slot `idx` (`ctx.cold_set_idx(idx, v)`).
    StoreCold(u32),

    // Arithmetic (binary unless noted). Binary opcodes pop the right
    // operand first, then the left, and push `left <op> right`.
    Add,
    Sub,
    Mul,
    Div,
    /// Remainder via Rust's `%` on `f64`.
    Mod,
    /// `left.powf(right)`. Also emitted for the `pow(a, b)` call.
    Pow,
    /// Unary negation of the top of stack.
    Neg,

    // Comparisons (push 1.0 / 0.0)
    Eq,
    Neq,
    Gt,
    Lt,
    Geq,
    Leq,

    // Logical (treat operand as bool via `!= 0.0`; push 1.0 / 0.0)
    And,
    Or,
    /// Unary: pushes `1.0` when the operand is `0.0`, else `0.0`.
    Not,

    // Math 1-arg (replace the top of stack in place)
    Sin,
    Cos,
    Tan,
    Asin,
    Acos,
    Atan,
    Sqrt,
    Exp,
    /// `log(x)` — natural logarithm (same implementation as [`Op::Ln`]).
    Log,
    Log10,
    Ln,
    AbsF,
    /// `sign(x)` — implemented with `f64::signum`.
    SignF,
    Floor,
    Ceil,
    Round,
    Fract,
    Trunc,
    Sinh,
    Cosh,
    Tanh,
    /// `sqr(x)` — `x * x`.
    Sqr,
    /// `rad(x)` — degrees to radians.
    ToRad,
    /// `deg(x)` — radians to degrees.
    ToDeg,
    /// `int(x)` — truncation toward zero (same as [`Op::Trunc`]).
    Int,
    /// `bnot(x)` — `1.0` when `x == 0.0`, else `0.0`.
    Bnot,

    // Math 2-arg
    Atan2,
    /// `fmod(a, b)` — `a % b`, same as [`Op::Mod`].
    FmodF,
    MinF,
    MaxF,
    /// `above(a, b)` — `1.0` when `a > b`, else `0.0`.
    Above,
    /// `below(a, b)` — `1.0` when `a < b`, else `0.0`.
    Below,
    /// `equal(a, b)` — `1.0` when `a == b`, else `0.0`.
    Equal,
    /// `band(a, b)` — logical AND on "non-zero is true"; pushes 1.0 / 0.0.
    Band,
    /// `bor(a, b)` — logical OR on "non-zero is true"; pushes 1.0 / 0.0.
    Bor,
    /// `sigmoid(x, center)` — `1 / (1 + exp(-(x - center)))`.
    Sigmoid,

    // Math 3-arg
    /// `clamp(x, lo, hi)`.
    ClampF,
    /// MD2 `if(cond, then, else)` (and `milkif`) — `if` returns `then`
    /// when `cond != 0`, else `else`. The bytecode evaluates *both*
    /// branches and selects — same as evalexpr's builtin `if` does.
    IfSelect,

    // --- Stateful builtins ---
    /// `rand(max)` — pseudo-random value in `[0, max)` seeded by the
    /// system clock (matches the evalexpr-registered `rand`).
    Rand,
    /// `gmegabuf(idx)` — read slot `idx` from the thread-local gmegabuf.
    /// Out-of-range / non-finite indices push `0.0`.
    Gmegabuf,
    /// `megabuf(idx)` — same shape as `Gmegabuf`; the interpreter
    /// services both through the same gmegabuf read.
    Megabuf,
    /// `gmegabuf_set(idx, val)` — write `val` to slot `idx`, then push
    /// `val` as the value of the call.
    GmegabufSet,
    /// `megabuf_set(idx, val)` — same shape as `GmegabufSet`; the
    /// interpreter services both through the same gmegabuf write.
    MegabufSet,
}
163
/// A reason the bytecode compiler couldn't lower a given AST node.
/// The caller falls back to evalexpr's `Node` evaluation in that case.
///
/// Both variants carry a free-form message that is only used for
/// diagnostics (it is echoed by the `Display` impl).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
    /// An operator or function we don't lower yet. The string carries
    /// the human-readable name for debugging.
    Unsupported(String),
    /// The AST shape was malformed (e.g., an `Assign` with the wrong
    /// number of children) — likely a bug or unsupported syntax that
    /// reached evalexpr's parser but tripped our walker.
    Malformed(String),
}
176
177impl std::fmt::Display for CompileError {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        match self {
180            Self::Unsupported(s) => write!(f, "unsupported in bytecode VM: {s}"),
181            Self::Malformed(s) => write!(f, "malformed AST: {s}"),
182        }
183    }
184}
185
// Marker impl: relies on the derived `Debug` and the `Display` impl for
// this type; no `source()` override because the error wraps no cause.
impl std::error::Error for CompileError {}
187
/// A bytecode-compiled per_point / per_pixel block. Constructed by
/// [`CompiledBytecode::try_compile`] and replayed by
/// [`CompiledBytecode::run`].
///
/// Cold variables (custom user vars + cold builtins) are addressed by
/// their slot index in the [`MilkContext::cold`] slab — resolved once
/// at compile time so `Op::LoadCold(idx)` / `Op::StoreCold(idx)` skip
/// the `HashMap` probe and `String::from` allocation the older
/// name-keyed path paid on every store.
#[derive(Debug, Clone)]
pub struct CompiledBytecode {
    /// Straight-line opcode sequence; executed front to back by `run`.
    code: Vec<Op>,
    /// Constant pool indexed by the payload of `Op::PushConst`.
    consts: Vec<f64>,
}
202
203impl CompiledBytecode {
204    /// Translate a block of evalexpr `Node`s into bytecode against the
205    /// supplied context. Cold-variable names referenced by the block
206    /// are interned into [`MilkContext::cold`] so the emitted opcodes
207    /// carry slab indices instead of string keys.
208    ///
209    /// Returns `Err(Unsupported)` on the first operator / function the
210    /// compiler doesn't lower — the caller can fall back to evaluating
211    /// the `Vec<Node>` through evalexpr directly.
212    pub fn try_compile(nodes: &[Node], ctx: &mut MilkContext) -> Result<Self, CompileError> {
213        let mut c = Compiler::new(ctx);
214        for node in nodes {
215            c.compile_node(node)?;
216            // Each top-level statement leaves exactly one value on the
217            // stack. Drop it: per_point / per_pixel discards
218            // intermediate results.
219            c.emit(Op::Pop);
220        }
221        Ok(c.finish())
222    }
223
224    /// Run the bytecode against `ctx`. Mutates the context in place
225    /// (hot vars, q-channels, custom vars).
226    ///
227    /// # Safety footprint
228    ///
229    /// The inner loop uses `get_unchecked` to elide bounds checks on
230    /// `stack`, `consts`, and `code`. The invariants the bytecode
231    /// compiler upholds are:
232    ///
233    /// - `sp <= STACK_SIZE` at every opcode boundary — checked by
234    ///   `debug_assert!` on each push.
235    /// - `i < consts.len()` for every `PushConst(i)` — `add_const`
236    ///   returns the index it just pushed, so the bytecode never
237    ///   emits a stale index.
238    /// - `pc < code.len()` — the `while` header guards each fetch.
239    ///
240    /// On the bench-presets corpus this shaves a few percent off
241    /// per-sample VM time; on a tight inner loop with `Op::Add` /
242    /// `Op::Mul`, every bounds check elided counts.
243    pub fn run(&self, ctx: &mut MilkContext) {
244        let mut stack = [0.0f64; STACK_SIZE];
245        let mut sp: usize = 0;
246
247        let code = self.code.as_slice();
248        let consts = self.consts.as_slice();
249        let code_len = code.len();
250
251        let mut pc = 0;
252        // SAFETY: see method-level "Safety footprint" — the bytecode
253        // compiler emits well-formed indices and the stack depth never
254        // exceeds `STACK_SIZE` on any input that round-trips through
255        // `Compiler::compile_node` (verified by `debug_assert` on
256        // pushes).
257        unsafe {
258            while pc < code_len {
259                let op = *code.get_unchecked(pc);
260                pc += 1;
261                match op {
262                    Op::PushConst(i) => {
263                        debug_assert!(sp < STACK_SIZE);
264                        *stack.get_unchecked_mut(sp) = *consts.get_unchecked(i as usize);
265                        sp += 1;
266                    }
267                    Op::Pop => {
268                        sp -= 1;
269                    }
270                    Op::LoadHot(i) => {
271                        debug_assert!(sp < STACK_SIZE);
272                        *stack.get_unchecked_mut(sp) = ctx.hot_get_idx(i as usize);
273                        sp += 1;
274                    }
275                    Op::StoreHot(i) => {
276                        sp -= 1;
277                        ctx.hot_set_idx(i as usize, *stack.get_unchecked(sp));
278                    }
279                    Op::LoadQ(i) => {
280                        debug_assert!(sp < STACK_SIZE);
281                        *stack.get_unchecked_mut(sp) = ctx.q_get_idx(i as usize);
282                        sp += 1;
283                    }
284                    Op::StoreQ(i) => {
285                        sp -= 1;
286                        ctx.q_set_idx(i as usize, *stack.get_unchecked(sp));
287                    }
288                    Op::LoadCold(i) => {
289                        debug_assert!(sp < STACK_SIZE);
290                        *stack.get_unchecked_mut(sp) = ctx.cold_get_idx(i as usize);
291                        sp += 1;
292                    }
293                    Op::StoreCold(i) => {
294                        sp -= 1;
295                        ctx.cold_set_idx(i as usize, *stack.get_unchecked(sp));
296                    }
297
298                    // --- Binary arithmetic ---
299                    Op::Add => bin_op(&mut stack, &mut sp, |a, b| a + b),
300                    Op::Sub => bin_op(&mut stack, &mut sp, |a, b| a - b),
301                    Op::Mul => bin_op(&mut stack, &mut sp, |a, b| a * b),
302                    Op::Div => bin_op(&mut stack, &mut sp, |a, b| a / b),
303                    Op::Mod => bin_op(&mut stack, &mut sp, |a, b| a % b),
304                    Op::Pow => bin_op(&mut stack, &mut sp, f64::powf),
305                    Op::Neg => {
306                        let v = *stack.get_unchecked(sp - 1);
307                        *stack.get_unchecked_mut(sp - 1) = -v;
308                    }
309
310                    // --- Comparisons ---
311                    Op::Eq => bin_op(&mut stack, &mut sp, |a, b| if a == b { 1.0 } else { 0.0 }),
312                    Op::Neq => bin_op(&mut stack, &mut sp, |a, b| if a != b { 1.0 } else { 0.0 }),
313                    Op::Gt => bin_op(&mut stack, &mut sp, |a, b| if a > b { 1.0 } else { 0.0 }),
314                    Op::Lt => bin_op(&mut stack, &mut sp, |a, b| if a < b { 1.0 } else { 0.0 }),
315                    Op::Geq => bin_op(&mut stack, &mut sp, |a, b| if a >= b { 1.0 } else { 0.0 }),
316                    Op::Leq => bin_op(&mut stack, &mut sp, |a, b| if a <= b { 1.0 } else { 0.0 }),
317
318                    // --- Logical (true iff non-zero) ---
319                    Op::And => bin_op(&mut stack, &mut sp, |a, b| {
320                        if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 }
321                    }),
322                    Op::Or => bin_op(&mut stack, &mut sp, |a, b| {
323                        if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 }
324                    }),
325                    Op::Not => {
326                        let v = *stack.get_unchecked(sp - 1);
327                        *stack.get_unchecked_mut(sp - 1) = if v == 0.0 { 1.0 } else { 0.0 };
328                    }
329
330                    // --- Math 1-arg ---
331                    Op::Sin => unary(&mut stack, sp, f64::sin),
332                    Op::Cos => unary(&mut stack, sp, f64::cos),
333                    Op::Tan => unary(&mut stack, sp, f64::tan),
334                    Op::Asin => unary(&mut stack, sp, f64::asin),
335                    Op::Acos => unary(&mut stack, sp, f64::acos),
336                    Op::Atan => unary(&mut stack, sp, f64::atan),
337                    Op::Sqrt => unary(&mut stack, sp, f64::sqrt),
338                    Op::Exp => unary(&mut stack, sp, f64::exp),
339                    Op::Log => unary(&mut stack, sp, f64::ln),
340                    Op::Log10 => unary(&mut stack, sp, f64::log10),
341                    Op::Ln => unary(&mut stack, sp, f64::ln),
342                    Op::AbsF => unary(&mut stack, sp, f64::abs),
343                    Op::SignF => unary(&mut stack, sp, |x| x.signum()),
344                    Op::Floor => unary(&mut stack, sp, f64::floor),
345                    Op::Ceil => unary(&mut stack, sp, f64::ceil),
346                    Op::Round => unary(&mut stack, sp, f64::round),
347                    Op::Fract => unary(&mut stack, sp, f64::fract),
348                    Op::Trunc => unary(&mut stack, sp, f64::trunc),
349                    Op::Sinh => unary(&mut stack, sp, f64::sinh),
350                    Op::Cosh => unary(&mut stack, sp, f64::cosh),
351                    Op::Tanh => unary(&mut stack, sp, f64::tanh),
352                    Op::Sqr => unary(&mut stack, sp, |x| x * x),
353                    Op::ToRad => unary(&mut stack, sp, f64::to_radians),
354                    Op::ToDeg => unary(&mut stack, sp, f64::to_degrees),
355                    Op::Int => unary(&mut stack, sp, f64::trunc),
356                    Op::Bnot => unary(&mut stack, sp, |x| if x == 0.0 { 1.0 } else { 0.0 }),
357
358                    // --- Math 2-arg ---
359                    Op::Atan2 => bin_op(&mut stack, &mut sp, f64::atan2),
360                    Op::FmodF => bin_op(&mut stack, &mut sp, |a, b| a % b),
361                    Op::MinF => bin_op(&mut stack, &mut sp, f64::min),
362                    Op::MaxF => bin_op(&mut stack, &mut sp, f64::max),
363                    Op::Above => bin_op(&mut stack, &mut sp, |a, b| if a > b { 1.0 } else { 0.0 }),
364                    Op::Below => bin_op(&mut stack, &mut sp, |a, b| if a < b { 1.0 } else { 0.0 }),
365                    Op::Equal => bin_op(&mut stack, &mut sp, |a, b| if a == b { 1.0 } else { 0.0 }),
366                    Op::Band => bin_op(&mut stack, &mut sp, |a, b| {
367                        if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 }
368                    }),
369                    Op::Bor => bin_op(&mut stack, &mut sp, |a, b| {
370                        if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 }
371                    }),
372                    Op::Sigmoid => bin_op(&mut stack, &mut sp, |x, center| {
373                        // Same shape as `math_functions::sigmoid`.
374                        1.0 / (1.0 + (-(x - center)).exp())
375                    }),
376
377                    // --- Math 3-arg ---
378                    Op::ClampF => {
379                        // stack: ..., x, lo, hi
380                        sp -= 3;
381                        let x = *stack.get_unchecked(sp);
382                        let lo = *stack.get_unchecked(sp + 1);
383                        let hi = *stack.get_unchecked(sp + 2);
384                        *stack.get_unchecked_mut(sp) = x.clamp(lo, hi);
385                        sp += 1;
386                    }
387                    Op::IfSelect => {
388                        // stack: ..., cond, then, else
389                        sp -= 3;
390                        let cond = *stack.get_unchecked(sp);
391                        let then_v = *stack.get_unchecked(sp + 1);
392                        let else_v = *stack.get_unchecked(sp + 2);
393                        *stack.get_unchecked_mut(sp) = if cond != 0.0 { then_v } else { else_v };
394                        sp += 1;
395                    }
396
397                    // --- Stateful builtins ---
398                    Op::Rand => {
399                        // `rand(max)` matches `math_functions::rand`: nanosecond
400                        // clock × max / 1e6. Same pseudo-randomness profile, no
401                        // RNG state in the VM.
402                        use std::time::{SystemTime, UNIX_EPOCH};
403                        let max = *stack.get_unchecked(sp - 1);
404                        let seed = SystemTime::now()
405                            .duration_since(UNIX_EPOCH)
406                            .map(|d| d.as_nanos())
407                            .unwrap_or(0);
408                        *stack.get_unchecked_mut(sp - 1) =
409                            ((seed % 1_000_000) as f64 / 1_000_000.0) * max;
410                    }
411                    Op::Gmegabuf | Op::Megabuf => {
412                        // `megabuf` and `gmegabuf` share the same backing in
413                        // this crate (see `math_functions::gmegabuf` doc).
414                        let idx = *stack.get_unchecked(sp - 1);
415                        *stack.get_unchecked_mut(sp - 1) =
416                            crate::math_functions::gmegabuf::read(idx);
417                    }
418                    Op::GmegabufSet | Op::MegabufSet => {
419                        // `<name>_set(idx, val)` evaluates to `val` (matches
420                        // the evalexpr-registered functions).
421                        sp -= 2;
422                        let idx = *stack.get_unchecked(sp);
423                        let val = *stack.get_unchecked(sp + 1);
424                        crate::math_functions::gmegabuf::write(idx, val);
425                        *stack.get_unchecked_mut(sp) = val;
426                        sp += 1;
427                    }
428                }
429            }
430        } // unsafe
431    }
432}
433
434#[inline(always)]
435fn bin_op(stack: &mut [f64; STACK_SIZE], sp: &mut usize, f: impl FnOnce(f64, f64) -> f64) {
436    // SAFETY: the bytecode compiler emits a `bin_op` only after pushing
437    // its two operands, so `*sp >= 2` here. `*sp - 1` and `*sp - 2` are
438    // in-range slots; the stack is fixed at `STACK_SIZE` so the writes
439    // back can't overflow.
440    *sp -= 1;
441    debug_assert!(*sp >= 1 && *sp < STACK_SIZE);
442    unsafe {
443        let b = *stack.get_unchecked(*sp);
444        let a = *stack.get_unchecked(*sp - 1);
445        *stack.get_unchecked_mut(*sp - 1) = f(a, b);
446    }
447}
448
449/// Strip evalexpr's parenthesis-wrapper `RootNode`s. Every group like
450/// `(expr)` becomes a single-child RootNode in the tree, so descending
451/// into function args means peeling them one layer at a time until we
452/// reach a real operator.
453fn unwrap_root(node: &Node) -> &Node {
454    let mut n = node;
455    while matches!(n.operator(), Operator::RootNode) && n.children().len() == 1 {
456        n = &n.children()[0];
457    }
458    n
459}
460
461#[inline(always)]
462fn unary(stack: &mut [f64; STACK_SIZE], sp: usize, f: impl FnOnce(f64) -> f64) {
463    // SAFETY: a `unary` is emitted after the operand push, so `sp >= 1`.
464    debug_assert!((1..=STACK_SIZE).contains(&sp));
465    unsafe {
466        let v = *stack.get_unchecked(sp - 1);
467        *stack.get_unchecked_mut(sp - 1) = f(v);
468    }
469}
470
/// Single-use AST → bytecode translator. Accumulates opcodes and the
/// constant pool while walking evalexpr nodes, then hands both over via
/// `finish`.
struct Compiler<'a> {
    /// Opcodes emitted so far, in execution order.
    code: Vec<Op>,
    /// Constant pool; deduplicated by bit pattern in `add_const`.
    consts: Vec<f64>,
    /// Borrowed context — used to intern cold-variable names into the
    /// slab so each `LoadCold`/`StoreCold` opcode carries a slab index
    /// rather than a name.
    ctx: &'a mut MilkContext,
}
479
480impl<'a> Compiler<'a> {
481    fn new(ctx: &'a mut MilkContext) -> Self {
482        Self {
483            code: Vec::new(),
484            consts: Vec::new(),
485            ctx,
486        }
487    }
488
    /// Append `op` to the end of the program being built.
    fn emit(&mut self, op: Op) {
        self.code.push(op);
    }
492
493    fn add_const(&mut self, f: f64) -> u32 {
494        // Linear scan to dedup. Per_point blocks have a handful of
495        // constants — the scan is cheaper than a HashMap probe + hash
496        // (`f64::to_bits` would be needed since `f64: !Eq`).
497        for (i, &c) in self.consts.iter().enumerate() {
498            if c.to_bits() == f.to_bits() {
499                return i as u32;
500            }
501        }
502        let idx = self.consts.len() as u32;
503        self.consts.push(f);
504        idx
505    }
506
    /// Intern a cold-variable name into [`MilkContext::cold`] and return
    /// its slab index. Subsequent references to the same name reuse the
    /// existing slot; auto-seeded MD2 defaults (`bass`, `zoom`, …) hit
    /// pre-existing slots without growing the slab.
    ///
    /// The index returned by `cold_intern` is narrowed to `u32` to fit
    /// the `LoadCold` / `StoreCold` payload.
    fn add_cold(&mut self, name: &str) -> u32 {
        self.ctx.cold_intern(name) as u32
    }
514
515    fn finish(self) -> CompiledBytecode {
516        CompiledBytecode {
517            code: self.code,
518            consts: self.consts,
519        }
520    }
521
522    /// Compile a node so that it leaves exactly one f64 on the stack.
523    fn compile_node(&mut self, node: &Node) -> Result<(), CompileError> {
524        use Operator::*;
525        let op = node.operator();
526        let children = node.children();
527        match op {
528            RootNode => {
529                // RootNode evaluates its first child; the rest are
530                // discarded by evalexpr too.
531                if let Some(first) = children.first() {
532                    self.compile_node(first)?;
533                } else {
534                    let i = self.add_const(0.0);
535                    self.emit(Op::PushConst(i));
536                }
537            }
538            Chain => {
539                // `a; b; c` — eval each, discard all but the last.
540                if children.is_empty() {
541                    let i = self.add_const(0.0);
542                    self.emit(Op::PushConst(i));
543                } else {
544                    let last = children.len() - 1;
545                    for (i, child) in children.iter().enumerate() {
546                        self.compile_node(child)?;
547                        if i != last {
548                            self.emit(Op::Pop);
549                        }
550                    }
551                }
552            }
553            Const { value } => {
554                let f = match value {
555                    Value::Float(f) => *f,
556                    Value::Int(i) => *i as f64,
557                    Value::Boolean(b) => {
558                        if *b {
559                            1.0
560                        } else {
561                            0.0
562                        }
563                    }
564                    other => {
565                        return Err(CompileError::Unsupported(format!("Const({other:?})")));
566                    }
567                };
568                let i = self.add_const(f);
569                self.emit(Op::PushConst(i));
570            }
571            VariableIdentifierRead { identifier } => self.emit_read(identifier),
572            VariableIdentifierWrite { .. } => {
573                // Bare write nodes only appear as the LHS of an
574                // Assign-family parent; reaching one standalone means
575                // we tripped on an unsupported AST shape.
576                return Err(CompileError::Malformed(
577                    "bare VariableIdentifierWrite".into(),
578                ));
579            }
580            Assign => {
581                self.compile_assign(children, None)?;
582            }
583            AddAssign => self.compile_assign(children, Some(Op::Add))?,
584            SubAssign => self.compile_assign(children, Some(Op::Sub))?,
585            MulAssign => self.compile_assign(children, Some(Op::Mul))?,
586            DivAssign => self.compile_assign(children, Some(Op::Div))?,
587            ModAssign => self.compile_assign(children, Some(Op::Mod))?,
588            ExpAssign => self.compile_assign(children, Some(Op::Pow))?,
589            AndAssign => self.compile_assign(children, Some(Op::And))?,
590            OrAssign => self.compile_assign(children, Some(Op::Or))?,
591
592            Add => self.compile_binary(children, Op::Add)?,
593            Sub => self.compile_binary(children, Op::Sub)?,
594            Mul => self.compile_binary(children, Op::Mul)?,
595            Div => self.compile_binary(children, Op::Div)?,
596            Mod => self.compile_binary(children, Op::Mod)?,
597            Exp => self.compile_binary(children, Op::Pow)?,
598            Neg => self.compile_unary(children, Op::Neg)?,
599            Eq => self.compile_binary(children, Op::Eq)?,
600            Neq => self.compile_binary(children, Op::Neq)?,
601            Gt => self.compile_binary(children, Op::Gt)?,
602            Lt => self.compile_binary(children, Op::Lt)?,
603            Geq => self.compile_binary(children, Op::Geq)?,
604            Leq => self.compile_binary(children, Op::Leq)?,
605            And => self.compile_binary(children, Op::And)?,
606            Or => self.compile_binary(children, Op::Or)?,
607            Not => self.compile_unary(children, Op::Not)?,
608
609            Tuple => {
610                // Standalone tuples have no f64 representation. The
611                // function-call path consumes a Tuple child directly.
612                return Err(CompileError::Unsupported("standalone Tuple".into()));
613            }
614            FunctionIdentifier { identifier } => self.compile_call(identifier, children)?,
615        }
616        Ok(())
617    }
618
619    /// Emit either `LoadHot(i)`, `LoadQ(i)`, or `LoadCold(name_idx)`
620    /// depending on which class the identifier falls into.
621    fn emit_read(&mut self, name: &str) {
622        if let Some(idx) = hot_index_of(name) {
623            self.emit(Op::LoadHot(idx as u8));
624        } else if let Some(idx) = q_index_of(name) {
625            self.emit(Op::LoadQ(idx as u8));
626        } else {
627            let i = self.add_cold(name);
628            self.emit(Op::LoadCold(i));
629        }
630    }
631
632    fn emit_write(&mut self, name: &str) {
633        if let Some(idx) = hot_index_of(name) {
634            self.emit(Op::StoreHot(idx as u8));
635        } else if let Some(idx) = q_index_of(name) {
636            self.emit(Op::StoreQ(idx as u8));
637        } else {
638            let i = self.add_cold(name);
639            self.emit(Op::StoreCold(i));
640        }
641    }
642
643    fn compile_unary(&mut self, children: &[Node], op: Op) -> Result<(), CompileError> {
644        if children.len() != 1 {
645            return Err(CompileError::Malformed(format!(
646                "unary op with {} children",
647                children.len()
648            )));
649        }
650        self.compile_node(&children[0])?;
651        self.emit(op);
652        Ok(())
653    }
654
655    fn compile_binary(&mut self, children: &[Node], op: Op) -> Result<(), CompileError> {
656        if children.len() != 2 {
657            return Err(CompileError::Malformed(format!(
658                "binary op with {} children",
659                children.len()
660            )));
661        }
662        self.compile_node(&children[0])?;
663        self.compile_node(&children[1])?;
664        self.emit(op);
665        Ok(())
666    }
667
668    /// Compile `x = expr` (when `compound` is `None`) or `x op= expr`
669    /// (when `compound` is `Some(arith_op)`). In both cases the
670    /// statement leaves a placeholder 0.0 on the stack so the outer
671    /// statement-discard `Pop` stays balanced.
672    fn compile_assign(
673        &mut self,
674        children: &[Node],
675        compound: Option<Op>,
676    ) -> Result<(), CompileError> {
677        if children.len() != 2 {
678            return Err(CompileError::Malformed(format!(
679                "Assign with {} children",
680                children.len()
681            )));
682        }
683        let name = match children[0].operator() {
684            Operator::VariableIdentifierWrite { identifier } => identifier.clone(),
685            other => {
686                return Err(CompileError::Unsupported(format!("Assign LHS: {other:?}")));
687            }
688        };
689
690        if let Some(arith) = compound {
691            // Read current value first (LHS load comes before RHS so
692            // the stack ends up [..., lhs, rhs] in order for the binop).
693            self.emit_read(&name);
694            self.compile_node(&children[1])?;
695            self.emit(arith);
696        } else {
697            self.compile_node(&children[1])?;
698        }
699        self.emit_write(&name);
700
701        // Assign returns `Value::Empty` in evalexpr — mimic with 0.0.
702        let i = self.add_const(0.0);
703        self.emit(Op::PushConst(i));
704        Ok(())
705    }
706
707    /// Lower a function call. evalexpr wraps every parenthesised
708    /// subexpression in a `RootNode`, so a call like `pow(2, 5)`
709    /// lands as `FunctionIdentifier → RootNode → Tuple → [RootNode→2,
710    /// RootNode→5]`. We unwrap the outer `RootNode`, then either split
711    /// a `Tuple` into args or treat the whole thing as a single arg.
712    fn compile_call(&mut self, name: &str, children: &[Node]) -> Result<(), CompileError> {
713        if children.len() != 1 {
714            return Err(CompileError::Malformed(format!(
715                "Function `{}` with {} top children",
716                name,
717                children.len()
718            )));
719        }
720        let arg_node = unwrap_root(&children[0]);
721        let args: Vec<&Node> = match arg_node.operator() {
722            Operator::Tuple => arg_node.children().iter().map(unwrap_root).collect(),
723            _ => vec![arg_node],
724        };
725
726        let (op, arity) = match name {
727            // 1-arg math
728            "sin" => (Op::Sin, 1),
729            "cos" => (Op::Cos, 1),
730            "tan" => (Op::Tan, 1),
731            "asin" => (Op::Asin, 1),
732            "acos" => (Op::Acos, 1),
733            "atan" => (Op::Atan, 1),
734            "sqrt" => (Op::Sqrt, 1),
735            "exp" => (Op::Exp, 1),
736            "log" => (Op::Log, 1),
737            "ln" => (Op::Ln, 1),
738            "log10" => (Op::Log10, 1),
739            "abs" => (Op::AbsF, 1),
740            "sign" => (Op::SignF, 1),
741            "floor" => (Op::Floor, 1),
742            "ceil" => (Op::Ceil, 1),
743            "round" => (Op::Round, 1),
744            "fract" => (Op::Fract, 1),
745            "trunc" => (Op::Trunc, 1),
746            "sinh" => (Op::Sinh, 1),
747            "cosh" => (Op::Cosh, 1),
748            "tanh" => (Op::Tanh, 1),
749            "sqr" => (Op::Sqr, 1),
750            "rad" => (Op::ToRad, 1),
751            "deg" => (Op::ToDeg, 1),
752            "int" => (Op::Int, 1),
753            "bnot" => (Op::Bnot, 1),
754            "rand" => (Op::Rand, 1),
755            "gmegabuf" => (Op::Gmegabuf, 1),
756            "megabuf" => (Op::Megabuf, 1),
757
758            // 2-arg
759            "atan2" => (Op::Atan2, 2),
760            "pow" => (Op::Pow, 2),
761            "fmod" => (Op::FmodF, 2),
762            "min" => (Op::MinF, 2),
763            "max" => (Op::MaxF, 2),
764            "above" => (Op::Above, 2),
765            "below" => (Op::Below, 2),
766            "equal" => (Op::Equal, 2),
767            "band" => (Op::Band, 2),
768            "bor" => (Op::Bor, 2),
769            "sigmoid" => (Op::Sigmoid, 2),
770            "gmegabuf_set" => (Op::GmegabufSet, 2),
771            "megabuf_set" => (Op::MegabufSet, 2),
772
773            // 3-arg
774            "clamp" => (Op::ClampF, 3),
775            "if" | "milkif" => (Op::IfSelect, 3),
776
777            other => {
778                return Err(CompileError::Unsupported(format!("function {other}")));
779            }
780        };
781
782        if args.len() != arity {
783            return Err(CompileError::Malformed(format!(
784                "{} expects {} args, got {}",
785                name,
786                arity,
787                args.len()
788            )));
789        }
790        for a in &args {
791            self.compile_node(a)?;
792        }
793        self.emit(op);
794        Ok(())
795    }
796}
797
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MilkEvaluator;

    /// Absolute tolerance used when comparing the two backends.
    const PARITY_EPS: f64 = 1e-9;

    /// Compile `src` and return the evaluator together with the
    /// compile result.
    ///
    /// The evaluator is handed back because the bytecode's `LoadCold` /
    /// `StoreCold` opcodes address slab slots interned at compile time,
    /// so a block that is going to be run must be run against the very
    /// context it was compiled against.
    fn compile_with_evaluator(
        src: &[&str],
    ) -> (MilkEvaluator, Result<CompiledBytecode, CompileError>) {
        let mut ev = MilkEvaluator::new();
        let owned: Vec<String> = src.iter().map(|s| s.to_string()).collect();
        let nodes = ev.compile_batch(&owned).unwrap();
        let compiled = CompiledBytecode::try_compile(&nodes, ev.context_mut());
        (ev, compiled)
    }

    /// Compile-only helper for tests that inspect the compile result
    /// and never run the bytecode.
    fn compile_block(src: &[&str]) -> Result<CompiledBytecode, CompileError> {
        compile_with_evaluator(src).1
    }

    /// Whether two backend results count as the same value.
    ///
    /// Exact equality is checked first so that matching infinities
    /// agree (`inf - inf` is NaN and would fail the tolerance check).
    /// Otherwise the values must be within `PARITY_EPS` of each other,
    /// or both be NaN.
    fn values_agree(a: f64, b: f64) -> bool {
        a == b || (a - b).abs() < PARITY_EPS || (a.is_nan() && b.is_nan())
    }

    /// Replay the same source through both backends and compare the
    /// final value of every named variable in `expected`. Used by every
    /// parity test below.
    ///
    /// `seeds` are written into both contexts before the block runs.
    /// A variable that is missing from either context after the run is
    /// a failure, so a typo in `expected` cannot pass vacuously.
    fn parity_assert(eqs: &[&str], seeds: &[(&str, f64)], expected: &[&str]) {
        // Bytecode path, run against the context it was compiled with.
        let (mut ev_bc, compiled) = compile_with_evaluator(eqs);
        let bc = compiled.expect("bytecode compile");
        for &(name, v) in seeds {
            ev_bc.context_mut().set(name, v);
        }
        bc.run(ev_bc.context_mut());

        // evalexpr path on a separate evaluator.
        let mut ev = MilkEvaluator::new();
        let owned: Vec<String> = eqs.iter().map(|s| s.to_string()).collect();
        let nodes = ev.compile_batch(&owned).unwrap();
        for &(name, v) in seeds {
            ev.context_mut().set(name, v);
        }
        for n in &nodes {
            ev.eval_compiled(n).unwrap();
        }

        for name in expected {
            let from_bytecode = ev_bc.context().get(name);
            let from_evalexpr = ev.context().get(name);
            match (from_bytecode, from_evalexpr) {
                (Some(a), Some(b)) => assert!(
                    values_agree(a, b),
                    "var `{name}` diverges: bytecode={a} evalexpr={b}"
                ),
                _ => panic!(
                    "var `{name}` missing after run: \
                     bytecode={from_bytecode:?} evalexpr={from_evalexpr:?}"
                ),
            }
        }
    }

    #[test]
    fn values_agree_handles_non_finite() {
        assert!(values_agree(f64::INFINITY, f64::INFINITY));
        assert!(values_agree(f64::NEG_INFINITY, f64::NEG_INFINITY));
        assert!(!values_agree(f64::INFINITY, f64::NEG_INFINITY));
        assert!(values_agree(f64::NAN, f64::NAN));
        assert!(!values_agree(f64::NAN, 0.0));
        assert!(values_agree(1.0, 1.0 + 1e-12));
        assert!(!values_agree(1.0, 1.0 + 1e-6));
    }

    #[test]
    fn arithmetic_parity() {
        // Note: MD2 / evalexpr is strictly typed on assignment — the
        // RHS must be a Float to land in a Float-init'd hot slot. The
        // evaluator's `preprocess_expression` promotes bare `x = 1` to
        // `x = 1.0`, but it does NOT recurse into compound RHS like
        // `x = 1 + 2 * 3`. Real-world presets always reference at
        // least one Float variable on the RHS so this is fine; tests
        // need to keep the same property explicit.
        parity_assert(
            &[
                "x = 1.0 + 2.0 * 3.0",
                "y = (4.0 - 1.0) / 2.0",
                "r = 5.0 % 3.0",
                "g = 2.0 ^ 10.0",
            ],
            &[],
            &["x", "y", "r", "g"],
        );
    }

    #[test]
    fn hot_var_carry_parity() {
        parity_assert(
            &["x = x + 0.5", "y = y * 2"],
            &[("x", 1.0), ("y", 0.25)],
            &["x", "y"],
        );
    }

    #[test]
    fn q_channel_parity() {
        parity_assert(
            &["q1 = 7", "q2 = q1 * 2", "x = q2 + q1"],
            &[],
            &["q1", "q2", "x"],
        );
    }

    #[test]
    fn math_function_parity() {
        parity_assert(
            &[
                "x = sin(1.0)",
                "y = sqrt(2.0)",
                "r = pow(2.0, 5.0)",
                "g = atan2(1.0, 1.0)",
                "b = abs(-3.5)",
                "a = clamp(7.0, 0.0, 5.0)",
            ],
            &[],
            &["x", "y", "r", "g", "b", "a"],
        );
    }

    #[test]
    fn conditional_parity() {
        parity_assert(
            &[
                "x = if(above(0.5, 0.1), 1.0, 0.0)",
                "y = if(below(0.5, 0.1), 1.0, 0.0)",
                "r = if(equal(0.5, 0.5), 7.0, 9.0)",
            ],
            &[],
            &["x", "y", "r"],
        );
    }

    #[test]
    fn compound_assign_parity() {
        // The bare `x = 10` is intentional: per the note in
        // `arithmetic_parity`, the preprocess step promotes a bare
        // integer assignment to `x = 10.0`, which is the shape real
        // preset sources have by the time they are compiled.
        parity_assert(
            &["x = 10", "x += 5", "x -= 2", "x *= 3", "x /= 2"],
            &[],
            &["x"],
        );
    }

    #[test]
    fn custom_var_parity() {
        parity_assert(
            &["myacc = 0", "myacc = myacc + sample", "x = myacc * 2"],
            &[("sample", 0.25)],
            &["myacc", "x"],
        );
    }

    #[test]
    fn unsupported_function_bails() {
        // A built-in we genuinely don't lower yet (no `lerp` opcode).
        let err = compile_block(&["x = lerp(0, 1, 0.5)"]).unwrap_err();
        assert!(
            matches!(err, CompileError::Unsupported(ref s) if s.contains("lerp")),
            "expected Unsupported(lerp), got {err:?}"
        );
    }

    #[test]
    fn rand_lowers_to_bytecode() {
        let (mut ev, compiled) = compile_with_evaluator(&["x = rand(10)"]);
        let bc = compiled.expect("rand should lower");
        ev.context_mut().set("x", 0.0);
        bc.run(ev.context_mut());
        let x = ev.context().get("x").unwrap();
        assert!((0.0..10.0).contains(&x), "rand(10) out of range: {x}");
    }

    #[test]
    fn gmegabuf_round_trips_through_bytecode() {
        // gmegabuf is read-only; gmegabuf_set writes; both lower to opcodes.
        let (mut ev, compiled) =
            compile_with_evaluator(&["gmegabuf_set(3, 7.5); x = gmegabuf(3)"]);
        let bc = compiled.expect("gmegabuf{,_set} should lower");
        ev.context_mut().set("x", 0.0);
        bc.run(ev.context_mut());
        assert_eq!(ev.context().get("x"), Some(7.5));
    }

    #[test]
    fn empty_block_compiles() {
        let (mut ev, compiled) = compile_with_evaluator(&[]);
        let bc = compiled.expect("empty block compiles");
        bc.run(ev.context_mut()); // no-op, no panic
    }
}