onedrop_eval/carry.rs
//! Carry-dependency analysis for per-point equation blocks.
//!
//! Custom-wave `per_point` blocks run sequentially across the wave's
//! samples, and the convention is that each iteration sees the previous
//! iteration's `x`/`y`/`r`/`g`/`b`/`a` as input ("trail across samples").
//! In practice many presets don't actually depend on that carry — they
//! write `x = …; y = …; r = …;` from the per-sample inputs (`sample`,
//! `value1`, `value2`) and never read the previous values. In that case
//! the samples are independent and the engine is free to evaluate them
//! in parallel, which is a 4-8× win on dense waves.
//!
//! This module walks the compiled [`Node`]s once at preset-load time and
//! classifies each per_point block as either *safe-to-parallelize* or
//! *carry-dependent*. The check is conservative — a block is only marked
//! safe if:
//!
//! 1. No carry-tracked variable (`x`, `y`, `r`, `g`, `b`, `a`) is read
//!    before being written in the block sequence (a read inside an
//!    equation that also writes the var is treated as carry-needing
//!    unless an earlier equation already wrote it).
//! 2. No variable outside `{x, y, r, g, b, a, sample, value1, value2}`
//!    is written. Stray writes to `qN` / custom vars / per-frame vars
//!    leak state across samples.
//! 3. No stateful / side-effecting function is called (`rand`,
//!    `gmegabuf`, `megabuf`, `gmegabuf_set`, `megabuf_set`, `exec2`,
//!    `exec3`, `loop`, `while`). These either carry per-eval state
//!    (gmegabuf is thread-local; samples on different workers would
//!    see split state) or are sample-order dependent (`rand`).
29
30use evalexpr::Node;
31use std::collections::HashSet;
32
/// Per-point variables that double as carry state: the caller of
/// `eval_per_point` seeds each of them from the previous sample's output,
/// so they thread state from one iteration to the next.
pub const CARRY_VARS: &[&str] = &["x", "y", "r", "g", "b", "a"];

/// Per-sample inputs. The caller re-seeds these on every iteration, so an
/// assignment to one inside per_point is overwritten before the next
/// sample runs and is therefore harmless.
const INPUT_VARS: &[&str] = &["sample", "value1", "value2"];

/// Functions whose behaviour depends on call order or on thread-local
/// state. A per_point block that calls any of them cannot be safely
/// sample-parallelised.
const UNSAFE_FUNCTIONS: &[&str] = &[
    "rand",
    "gmegabuf",
    "megabuf",
    "gmegabuf_set",
    "megabuf_set",
    "exec2",
    "exec3",
    "loop",
    "while",
];
55
/// Verdict of the carry analysis for one per_point block, consumed by the
/// parallel-samples pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerPointParallelism {
    /// Every sample can be evaluated on its own: the output of sample `S`
    /// is not an input to sample `S+1`.
    Safe,
    /// Samples must be evaluated in order. The block does at least one of:
    /// read a carried `{x,y,r,g,b,a}` value, write a variable outside the
    /// allowed output set, or call a stateful function.
    Sequential,
}

impl PerPointParallelism {
    /// Returns `true` only for [`PerPointParallelism::Safe`].
    #[inline]
    pub fn is_safe(self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly
        // whether it permits parallel evaluation.
        match self {
            Self::Safe => true,
            Self::Sequential => false,
        }
    }
}
75
76/// Analyse a compiled per_point block. Walks each equation's operator
77/// tree once and tracks a running set of "already-written" identifiers
78/// to distinguish "first read = carry" from "first read after write =
79/// uses fresh value".
80pub fn analyse_per_point(per_point: &[Node]) -> PerPointParallelism {
81 let mut already_written: HashSet<String> = HashSet::new();
82
83 for node in per_point {
84 // Functions first — a single occurrence of `rand` / `gmegabuf` /
85 // … taints the whole block.
86 for fname in node.iter_function_identifiers() {
87 if UNSAFE_FUNCTIONS.contains(&fname) {
88 return PerPointParallelism::Sequential;
89 }
90 }
91
92 // Reads of a carry-tracked var that has NOT been written by an
93 // earlier equation in the block mean sample S+1's value
94 // depends on sample S's output.
95 for ident in node.iter_read_variable_identifiers() {
96 if CARRY_VARS.contains(&ident) && !already_written.contains(ident) {
97 return PerPointParallelism::Sequential;
98 }
99 }
100
101 // Writes outside the allowed output set propagate side effects
102 // across samples (qN, custom vars, …). The shared per_point
103 // outputs (`x`/`y`/`r`/`g`/`b`/`a`) and the re-seeded inputs
104 // (`sample`/`value1`/`value2`) are fine.
105 for ident in node.iter_write_variable_identifiers() {
106 if !CARRY_VARS.contains(&ident) && !INPUT_VARS.contains(&ident) {
107 return PerPointParallelism::Sequential;
108 }
109 already_written.insert(ident.to_string());
110 }
111 }
112
113 PerPointParallelism::Safe
114}
115
#[cfg(test)]
mod tests {
    use super::PerPointParallelism::{Safe, Sequential};
    use super::*;
    use crate::MilkEvaluator;

    /// Compiles `eqs` with a fresh [`MilkEvaluator`]; panics if the batch
    /// fails to compile.
    fn compile(eqs: &[&str]) -> Vec<Node> {
        let sources: Vec<String> = eqs.iter().copied().map(String::from).collect();
        let mut evaluator = MilkEvaluator::new();
        evaluator.compile_batch(&sources).unwrap()
    }

    /// Compiles `eqs` and runs the carry analysis on the resulting block.
    fn classify(eqs: &[&str]) -> PerPointParallelism {
        analyse_per_point(&compile(eqs))
    }

    #[test]
    fn pure_output_writes_are_safe() {
        let verdict = classify(&[
            "x = sample",
            "y = 0.5 - value1 * 0.5",
            "r = 1",
            "g = 0",
            "b = 0",
            "a = 1",
        ]);
        assert_eq!(verdict, Safe);
    }

    #[test]
    fn read_of_x_before_write_is_carry() {
        assert_eq!(classify(&["x = x + 0.001"]), Sequential);
    }

    #[test]
    fn read_after_write_in_separate_equation_is_safe() {
        // The first equation assigns x from `sample`, so the read of x in
        // the second equation sees this iteration's value, not a carry.
        assert_eq!(classify(&["x = sample", "y = x * 2"]), Safe);
    }

    #[test]
    fn write_to_q_var_is_unsafe() {
        // A q1 written by per_point would leak to the next sample / the
        // next wave's per_frame.
        assert_eq!(classify(&["q1 = sample", "x = sample"]), Sequential);
    }

    #[test]
    fn write_to_custom_var_is_unsafe() {
        assert_eq!(classify(&["mycarry = sample", "x = mycarry"]), Sequential);
    }

    #[test]
    fn rand_call_taints_block() {
        assert_eq!(classify(&["x = rand(100)/100"]), Sequential);
    }

    #[test]
    fn gmegabuf_call_taints_block() {
        assert_eq!(classify(&["x = gmegabuf(0)"]), Sequential);
    }

    #[test]
    fn empty_block_is_safe_but_vacuous() {
        // With no equations nothing can read a carry, so the result is
        // trivially safe (the caller handles the empty case before
        // invoking the analysis anyway).
        assert_eq!(analyse_per_point(&[]), Safe);
    }

    #[test]
    fn writing_to_sample_input_is_safe() {
        // sample/value1/value2 are re-seeded by the caller on every
        // iteration, so writes to them within an iteration don't leak.
        assert_eq!(classify(&["sample = sample * 2", "x = sample"]), Safe);
    }
}
194}