onedrop_eval/
carry.rs

1//! Carry-dependency analysis for per-point equation blocks.
2//!
3//! Custom-wave `per_point` blocks run sequentially across the wave's
4//! samples and the convention is that each iteration sees the previous
5//! iteration's `x`/`y`/`r`/`g`/`b`/`a` as input ("trail across samples").
6//! In practice many presets don't actually depend on that carry — they
7//! write `x = …; y = …; r = …;` from the per-sample inputs (`sample`,
8//! `value1`, `value2`) and never read the previous values. In that case
9//! the samples are independent and the engine is free to evaluate them
10//! in parallel, which is a 4-8× win on dense waves.
11//!
12//! This module walks the compiled [`Node`]s once at preset-load time and
13//! classifies each per_point block as either *safe-to-parallelize* or
14//! *carry-dependent*. The check is conservative — a block is only marked
15//! safe if:
16//!
17//! 1. No carry-tracked variable (`x`, `y`, `r`, `g`, `b`, `a`) is read
18//!    before being written in the block sequence (a read inside an
19//!    equation that also writes the var is treated as carry-needing
20//!    unless an earlier equation already wrote it).
21//! 2. No variable outside `{x, y, r, g, b, a, sample, value1, value2}`
22//!    is written. Stray writes to `qN` / custom vars / per-frame vars
23//!    leak state across samples.
24//! 3. No stateful / side-effecting function is called (`rand`,
25//!    `gmegabuf`, `megabuf`, `gmegabuf_set`, `megabuf_set`, `exec2`,
26//!    `exec3`, `loop`, `while`). These either carry per-eval state
27//!    (gmegabuf is thread-local; samples on different workers would
28//!    see split state) or are sample-order dependent (`rand`).
29
use evalexpr::{Node, Operator};
use std::collections::HashSet;
32
/// Identifiers whose value is handed from one `eval_per_point` iteration to
/// the next: the caller seeds them with the previous sample's output.
pub const CARRY_VARS: &[&str] = &["x", "y", "r", "g", "b", "a"];

/// Inputs the caller overwrites before every sample. A per_point write to one
/// of these cannot leak, because the next sample starts from a fresh seed.
const INPUT_VARS: &[&str] = &["sample", "value1", "value2"];

/// Function names that force sequential evaluation: their behaviour depends on
/// call order or on thread-local state, so any call to one of them in a
/// per_point block rules out sample-level parallelism.
const UNSAFE_FUNCTIONS: &[&str] = &[
    // Result depends on the order in which samples are evaluated.
    "rand",
    // Buffer state kept across evaluations (gmegabuf is thread-local, so
    // samples handled by different workers would see split state).
    "gmegabuf",
    "megabuf",
    "gmegabuf_set",
    "megabuf_set",
    // NOTE(review): listed as unsafe by the module docs; the exact reason is
    // not spelled out here — presumably they can hide the state above.
    "exec2",
    "exec3",
    "loop",
    "while",
];
55
/// Verdict of the carry analysis for one per_point block, consumed by the
/// parallel-samples pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerPointParallelism {
    /// Every sample stands alone: nothing computed for sample `S` reaches
    /// sample `S+1`, so samples may be evaluated independently.
    Safe,
    /// Samples must run in order, because the block does at least one of:
    /// reads a carried `{x,y,r,g,b,a}` value, writes outside the permitted
    /// output set, or calls a stateful function.
    Sequential,
}

impl PerPointParallelism {
    /// `true` exactly when the verdict is [`PerPointParallelism::Safe`].
    #[inline]
    pub fn is_safe(self) -> bool {
        self == PerPointParallelism::Safe
    }
}
75
76/// Analyse a compiled per_point block. Walks each equation's operator
77/// tree once and tracks a running set of "already-written" identifiers
78/// to distinguish "first read = carry" from "first read after write =
79/// uses fresh value".
80pub fn analyse_per_point(per_point: &[Node]) -> PerPointParallelism {
81    let mut already_written: HashSet<String> = HashSet::new();
82
83    for node in per_point {
84        // Functions first — a single occurrence of `rand` / `gmegabuf` /
85        // … taints the whole block.
86        for fname in node.iter_function_identifiers() {
87            if UNSAFE_FUNCTIONS.contains(&fname) {
88                return PerPointParallelism::Sequential;
89            }
90        }
91
92        // Reads of a carry-tracked var that has NOT been written by an
93        // earlier equation in the block mean sample S+1's value
94        // depends on sample S's output.
95        for ident in node.iter_read_variable_identifiers() {
96            if CARRY_VARS.contains(&ident) && !already_written.contains(ident) {
97                return PerPointParallelism::Sequential;
98            }
99        }
100
101        // Writes outside the allowed output set propagate side effects
102        // across samples (qN, custom vars, …). The shared per_point
103        // outputs (`x`/`y`/`r`/`g`/`b`/`a`) and the re-seeded inputs
104        // (`sample`/`value1`/`value2`) are fine.
105        for ident in node.iter_write_variable_identifiers() {
106            if !CARRY_VARS.contains(&ident) && !INPUT_VARS.contains(&ident) {
107                return PerPointParallelism::Sequential;
108            }
109            already_written.insert(ident.to_string());
110        }
111    }
112
113    PerPointParallelism::Safe
114}
115
#[cfg(test)]
mod tests {
    use super::PerPointParallelism::{Safe, Sequential};
    use super::*;
    use crate::MilkEvaluator;

    /// Compiles `eqs` with a fresh evaluator. Panics if any equation fails to
    /// compile, since every fixture below is meant to be valid.
    fn compile(eqs: &[&str]) -> Vec<Node> {
        let sources: Vec<String> = eqs.iter().copied().map(String::from).collect();
        let mut evaluator = MilkEvaluator::new();
        evaluator.compile_batch(&sources).unwrap()
    }

    /// Compiles `eqs` and returns the carry-analysis verdict for the block.
    fn classify(eqs: &[&str]) -> PerPointParallelism {
        let nodes = compile(eqs);
        analyse_per_point(&nodes)
    }

    #[test]
    fn pure_output_writes_are_safe() {
        let block = [
            "x = sample",
            "y = 0.5 - value1 * 0.5",
            "r = 1",
            "g = 0",
            "b = 0",
            "a = 1",
        ];
        assert_eq!(classify(&block), Safe);
    }

    #[test]
    fn read_of_x_before_write_is_carry() {
        assert_eq!(classify(&["x = x + 0.001"]), Sequential);
    }

    #[test]
    fn read_after_write_in_separate_equation_is_safe() {
        // The first equation already wrote x from `sample` in this iteration,
        // so the second equation's read of x is fresh rather than carried.
        assert_eq!(classify(&["x = sample", "y = x * 2"]), Safe);
    }

    #[test]
    fn write_to_q_var_is_unsafe() {
        // A per_point write to q1 would leak into the next sample / the next
        // wave's per_frame.
        assert_eq!(classify(&["q1 = sample", "x = sample"]), Sequential);
    }

    #[test]
    fn write_to_custom_var_is_unsafe() {
        assert_eq!(classify(&["mycarry = sample", "x = mycarry"]), Sequential);
    }

    #[test]
    fn rand_call_taints_block() {
        assert_eq!(classify(&["x = rand(100)/100"]), Sequential);
    }

    #[test]
    fn gmegabuf_call_taints_block() {
        assert_eq!(classify(&["x = gmegabuf(0)"]), Sequential);
    }

    #[test]
    fn empty_block_is_safe_but_vacuous() {
        // With no equations nothing can read a carry, so the verdict is
        // trivially Safe (callers skip empty blocks before calling anyway).
        assert_eq!(analyse_per_point(&[]), Safe);
    }

    #[test]
    fn writing_to_sample_input_is_safe() {
        // The caller re-seeds sample/value1/value2 every iteration, so a write
        // to them inside one iteration cannot leak into the next.
        assert_eq!(classify(&["sample = sample * 2", "x = sample"]), Safe);
    }
}