// onedrop_eval/math_functions.rs

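//! Mathematical functions for MilkDrop expressions.
//!
//! This module provides all the mathematical functions needed for MilkDrop presets,
//! as evalexpr 13.0 does not include trigonometric or advanced math functions by default.
//! It also registers the EEL2 control-flow stubs (`loop`, `while`, `exec2`, `exec3`) and
//! the `gmegabuf` / `megabuf` scratch-buffer accessors that presets reference.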
5
6use evalexpr::{ContextWithMutableFunctions, DefaultNumericTypes, Function, HashMapContext, Value};
7
/// Thread-local persistent scratch buffer for MD2 `gmegabuf` / `megabuf`.
///
/// evalexpr `Function` closures can't capture mutable state, so the backing
/// array lives in a thread-local `RefCell<Vec<f64>>` cleared from
/// `MilkEvaluator::new()` (via `gmegabuf::reset`). Reads at indices past the
/// high-water mark return 0.0 (uninitialised). Writes grow the vector lazily,
/// never past [`MAX_SLOTS`]: an evaluator that never touches the buffer
/// allocates nothing, and a fill loop such as `loop(1024*1024, gmegabuf(i)=0)`
/// only allocates as far as it actually writes.
pub mod gmegabuf {
    use std::cell::RefCell;

    /// MD2's documented cap (1 048 576 slots = 8 MB at f64). Indices outside
    /// this range silently no-op so a malformed preset can't OOM the eval.
    pub const MAX_SLOTS: usize = 1 << 20;

    thread_local! {
        static BUFFER: RefCell<Vec<f64>> = const { RefCell::new(Vec::new()) };
    }

    /// Empty the buffer and release its storage. Called by
    /// `MilkEvaluator::new()` so each evaluator starts with fresh state on the
    /// current thread. (Within a single thread, evaluators can stomp on each
    /// other's buffer mid-frame; in practice each preset is loaded → eval'd →
    /// dropped sequentially.)
    pub fn reset() {
        // Taking the Vec drops its allocation. `clear()` would keep the full
        // capacity (up to ~8 MB after a fill loop) alive for the rest of the
        // thread's life.
        drop(BUFFER.take());
    }

    /// Convert an f64 index to a slot number.
    ///
    /// The index is truncated toward zero, so `-0.5` maps to slot 0. Returns
    /// `None` for NaN / ±inf or anything outside `[0, MAX_SLOTS)`.
    fn slot(idx: f64) -> Option<usize> {
        if !idx.is_finite() {
            return None;
        }
        // `as i64` saturates for huge magnitudes. `try_from` then rejects
        // negatives and, on 32-bit targets, values that don't fit in `usize`,
        // which `as usize` would silently truncate into a valid-looking slot.
        usize::try_from(idx as i64).ok().filter(|&i| i < MAX_SLOTS)
    }

    /// Read slot `idx`.
    ///
    /// Returns 0.0 when `idx` is out of range (see `slot`) or the slot has
    /// not been written since the last `reset`.
    pub fn read(idx: f64) -> f64 {
        slot(idx)
            .and_then(|i| BUFFER.with(|b| b.borrow().get(i).copied()))
            .unwrap_or(0.0)
    }

    /// Write `val` to slot `idx`, zero-filling any gap below it.
    ///
    /// Out-of-range indices are silently ignored.
    pub fn write(idx: f64, val: f64) {
        let Some(i) = slot(idx) else {
            return;
        };
        BUFFER.with(|b| {
            let mut buf = b.borrow_mut();
            if i >= buf.len() {
                buf.resize(i + 1, 0.0);
            }
            buf[i] = val;
        });
    }
}
69
70/// Coerce a `Value` to `f64`, accepting Int / Float / Boolean.
71///
72/// `Value::as_number()` only handles Int and Float, so any builtin driven from
73/// a Boolean operand (cmp result, `band`/`bor` arg after
74/// [`rewrite_logical_to_bandbor`]) trips on "Expected Number". MD2 EEL2's
75/// logical/numeric operators all collapse Boolean → 0.0/1.0, so we mirror
76/// that for our registered helpers (`band`, `bor`, `bnot`).
77fn coerce_to_f64(value: &Value) -> Result<f64, evalexpr::EvalexprError> {
78    match value {
79        Value::Boolean(b) => Ok(if *b { 1.0 } else { 0.0 }),
80        _ => value.as_number(),
81    }
82}
83
84/// Register all MilkDrop math functions in a HashMapContext.
85pub fn register_math_functions(context: &mut HashMapContext<DefaultNumericTypes>) {
86    // Trigonometric functions
87    context
88        .set_function(
89            "sin".into(),
90            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.sin()))),
91        )
92        .ok();
93
94    context
95        .set_function(
96            "cos".into(),
97            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.cos()))),
98        )
99        .ok();
100
101    context
102        .set_function(
103            "tan".into(),
104            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.tan()))),
105        )
106        .ok();
107
108    context
109        .set_function(
110            "asin".into(),
111            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.asin()))),
112        )
113        .ok();
114
115    context
116        .set_function(
117            "acos".into(),
118            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.acos()))),
119        )
120        .ok();
121
122    context
123        .set_function(
124            "atan".into(),
125            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.atan()))),
126        )
127        .ok();
128
129    context
130        .set_function(
131            "atan2".into(),
132            Function::new(|arg| {
133                if let Ok(tuple) = arg.as_tuple()
134                    && tuple.len() == 2
135                    && let (Ok(y), Ok(x)) = (tuple[0].as_number(), tuple[1].as_number())
136                {
137                    let y: f64 = y;
138                    let x: f64 = x;
139                    let y: f64 = y;
140                    let x: f64 = x;
141                    return Ok(Value::Float(y.atan2(x)));
142                }
143                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
144                    expected: 2..=2,
145                    actual: 1,
146                })
147            }),
148        )
149        .ok();
150
151    // Exponential and logarithmic functions
152    context
153        .set_function(
154            "sqrt".into(),
155            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.sqrt()))),
156        )
157        .ok();
158
159    context
160        .set_function(
161            "pow".into(),
162            Function::new(|arg| {
163                if let Ok(tuple) = arg.as_tuple()
164                    && tuple.len() == 2
165                    && let (Ok(base), Ok(exp)) = (tuple[0].as_number(), tuple[1].as_number())
166                {
167                    let base: f64 = base;
168                    let exp: f64 = exp;
169                    return Ok(Value::Float(base.powf(exp)));
170                }
171                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
172                    expected: 2..=2,
173                    actual: 1,
174                })
175            }),
176        )
177        .ok();
178
179    context
180        .set_function(
181            "exp".into(),
182            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.exp()))),
183        )
184        .ok();
185
186    context
187        .set_function(
188            "log".into(),
189            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.ln()))),
190        )
191        .ok();
192
193    context
194        .set_function(
195            "ln".into(),
196            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.ln()))),
197        )
198        .ok();
199
200    context
201        .set_function(
202            "log10".into(),
203            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.log10()))),
204        )
205        .ok();
206
207    // Absolute value and sign
208    context
209        .set_function(
210            "abs".into(),
211            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.abs()))),
212        )
213        .ok();
214
215    context
216        .set_function(
217            "sign".into(),
218            Function::new(|arg| {
219                arg.as_number().map(|n: f64| {
220                    if n > 0.0 {
221                        Value::Float(1.0)
222                    } else if n < 0.0 {
223                        Value::Float(-1.0)
224                    } else {
225                        Value::Float(0.0)
226                    }
227                })
228            }),
229        )
230        .ok();
231
232    // Rounding functions
233    context
234        .set_function(
235            "fract".into(),
236            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.fract()))),
237        )
238        .ok();
239
240    context
241        .set_function(
242            "trunc".into(),
243            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.trunc()))),
244        )
245        .ok();
246
247    // Modulo and clamping
248    context
249        .set_function(
250            "fmod".into(),
251            Function::new(|arg| {
252                if let Ok(tuple) = arg.as_tuple()
253                    && tuple.len() == 2
254                    && let (Ok(a), Ok(b)) = (tuple[0].as_number(), tuple[1].as_number())
255                {
256                    let a: f64 = a;
257                    let b: f64 = b;
258                    return Ok(Value::Float(a % b));
259                }
260                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
261                    expected: 2..=2,
262                    actual: 1,
263                })
264            }),
265        )
266        .ok();
267
268    context
269        .set_function(
270            "clamp".into(),
271            Function::new(|arg| {
272                if let Ok(tuple) = arg.as_tuple()
273                    && tuple.len() == 3
274                    && let (Ok(value), Ok(min_val), Ok(max_val)) = (
275                        tuple[0].as_number(),
276                        tuple[1].as_number(),
277                        tuple[2].as_number(),
278                    )
279                {
280                    let value: f64 = value;
281                    let min_val: f64 = min_val;
282                    let max_val: f64 = max_val;
283                    return Ok(Value::Float(value.max(min_val).min(max_val)));
284                }
285                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
286                    expected: 3..=3,
287                    actual: 1,
288                })
289            }),
290        )
291        .ok();
292
293    // Hyperbolic functions
294    context
295        .set_function(
296            "sinh".into(),
297            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.sinh()))),
298        )
299        .ok();
300
301    context
302        .set_function(
303            "cosh".into(),
304            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.cosh()))),
305        )
306        .ok();
307
308    context
309        .set_function(
310            "tanh".into(),
311            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.tanh()))),
312        )
313        .ok();
314
315    // Additional useful functions
316    context
317        .set_function(
318            "sqr".into(),
319            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n * n))),
320        )
321        .ok();
322
323    context
324        .set_function(
325            "rad".into(),
326            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.to_radians()))),
327        )
328        .ok();
329
330    context
331        .set_function(
332            "deg".into(),
333            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.to_degrees()))),
334        )
335        .ok();
336
337    // Random and comparison functions
338    context
339        .set_function(
340            "rand".into(),
341            Function::new(|arg| {
342                use std::time::{SystemTime, UNIX_EPOCH};
343                let max = arg.as_number()?;
344                let max: f64 = max;
345                // `duration_since(UNIX_EPOCH)` only fails on a pre-1970 clock;
346                // fall back to 0 so the equation still produces a number
347                // rather than panicking mid-frame.
348                let seed = SystemTime::now()
349                    .duration_since(UNIX_EPOCH)
350                    .map(|d| d.as_nanos())
351                    .unwrap_or(0);
352                let random = ((seed % 1000000) as f64 / 1000000.0) * max;
353                Ok(Value::Float(random))
354            }),
355        )
356        .ok();
357
358    context
359        .set_function(
360            "above".into(),
361            Function::new(|arg| {
362                if let Ok(tuple) = arg.as_tuple()
363                    && tuple.len() == 2
364                    && let (Ok(a), Ok(b)) = (tuple[0].as_number(), tuple[1].as_number())
365                {
366                    let a: f64 = a;
367                    let b: f64 = b;
368                    return Ok(Value::Float(if a > b { 1.0 } else { 0.0 }));
369                }
370                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
371                    expected: 2..=2,
372                    actual: 1,
373                })
374            }),
375        )
376        .ok();
377
378    context
379        .set_function(
380            "below".into(),
381            Function::new(|arg| {
382                if let Ok(tuple) = arg.as_tuple()
383                    && tuple.len() == 2
384                    && let (Ok(a), Ok(b)) = (tuple[0].as_number(), tuple[1].as_number())
385                {
386                    let a: f64 = a;
387                    let b: f64 = b;
388                    return Ok(Value::Float(if a < b { 1.0 } else { 0.0 }));
389                }
390                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
391                    expected: 2..=2,
392                    actual: 1,
393                })
394            }),
395        )
396        .ok();
397
398    context
399        .set_function(
400            "equal".into(),
401            Function::new(|arg| {
402                if let Ok(tuple) = arg.as_tuple()
403                    && tuple.len() == 2
404                    && let (Ok(a), Ok(b)) = (tuple[0].as_number(), tuple[1].as_number())
405                {
406                    let a: f64 = a;
407                    let b: f64 = b;
408                    return Ok(Value::Float(if (a - b).abs() < 1e-10 { 1.0 } else { 0.0 }));
409                }
410                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
411                    expected: 2..=2,
412                    actual: 1,
413                })
414            }),
415        )
416        .ok();
417
418    // Boolean functions — coercion accepts `Value::Boolean` so the `&&` /
419    // `||` / unary-`!` rewriters can hand us cmp results (Boolean) without
420    // `as_number()` rejecting.
421    context
422        .set_function(
423            "bnot".into(),
424            Function::new(|arg| {
425                coerce_to_f64(arg).map(|n: f64| Value::Float(if n == 0.0 { 1.0 } else { 0.0 }))
426            }),
427        )
428        .ok();
429
430    context
431        .set_function(
432            "band".into(),
433            Function::new(|arg| {
434                if let Ok(tuple) = arg.as_tuple()
435                    && tuple.len() == 2
436                    && let (Ok(a), Ok(b)) = (coerce_to_f64(&tuple[0]), coerce_to_f64(&tuple[1]))
437                {
438                    return Ok(Value::Float(if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 }));
439                }
440                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
441                    expected: 2..=2,
442                    actual: 1,
443                })
444            }),
445        )
446        .ok();
447
448    context
449        .set_function(
450            "bor".into(),
451            Function::new(|arg| {
452                if let Ok(tuple) = arg.as_tuple()
453                    && tuple.len() == 2
454                    && let (Ok(a), Ok(b)) = (coerce_to_f64(&tuple[0]), coerce_to_f64(&tuple[1]))
455                {
456                    return Ok(Value::Float(if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 }));
457                }
458                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
459                    expected: 2..=2,
460                    actual: 1,
461                })
462            }),
463        )
464        .ok();
465
466    // Type conversion
467    context
468        .set_function(
469            "int".into(),
470            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.trunc()))),
471        )
472        .ok();
473
474    // ---- Numeric-coercion overrides for `min` / `max` ----
475    //
476    // evalexpr 13 ships `min` and `max` as strict Float builtins; MD2 presets
477    // routinely call `max(0, res)` (Int literal as first arg) and the strict
478    // builtin errors with "Expected Float, got Int(0)". Override with
479    // `as_number()` versions that coerce Int and Boolean to Float
480    // transparently.
481    context
482        .set_function(
483            "max".into(),
484            Function::new(|arg| {
485                if let Ok(tuple) = arg.as_tuple()
486                    && tuple.len() == 2
487                    && let (Ok(a), Ok(b)) = (tuple[0].as_number(), tuple[1].as_number())
488                {
489                    let a: f64 = a;
490                    let b: f64 = b;
491                    return Ok(Value::Float(a.max(b)));
492                }
493                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
494                    expected: 2..=2,
495                    actual: 1,
496                })
497            }),
498        )
499        .ok();
500
501    context
502        .set_function(
503            "min".into(),
504            Function::new(|arg| {
505                if let Ok(tuple) = arg.as_tuple()
506                    && tuple.len() == 2
507                    && let (Ok(a), Ok(b)) = (tuple[0].as_number(), tuple[1].as_number())
508                {
509                    let a: f64 = a;
510                    let b: f64 = b;
511                    return Ok(Value::Float(a.min(b)));
512                }
513                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
514                    expected: 2..=2,
515                    actual: 1,
516                })
517            }),
518        )
519        .ok();
520
521    // `sigmoid(x, k)` — MD2 logistic: `1 / (1 + exp(-x*k))`. 49 / 470
522    // eval-failing presets in `test-presets-2` sample 2000 referenced this.
523    context
524        .set_function(
525            "sigmoid".into(),
526            Function::new(|arg| {
527                if let Ok(tuple) = arg.as_tuple()
528                    && tuple.len() == 2
529                    && let (Ok(x), Ok(k)) = (tuple[0].as_number(), tuple[1].as_number())
530                {
531                    let x: f64 = x;
532                    let k: f64 = k;
533                    return Ok(Value::Float(1.0 / (1.0 + (-x * k).exp())));
534                }
535                Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
536                    expected: 2..=2,
537                    actual: 1,
538                })
539            }),
540        )
541        .ok();
542
543    // `floor` / `ceil` / `round` — evalexpr 13 has them strict on Float;
544    // MD2 presets pass Int literals. Wrap with as_number() coercion so
545    // `floor(0)` and `ceil(1)` accept the literal cleanly.
546    context
547        .set_function(
548            "floor".into(),
549            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.floor()))),
550        )
551        .ok();
552
553    context
554        .set_function(
555            "ceil".into(),
556            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.ceil()))),
557        )
558        .ok();
559
560    context
561        .set_function(
562            "round".into(),
563            Function::new(|arg| arg.as_number().map(|n: f64| Value::Float(n.round()))),
564        )
565        .ok();
566
567    // `loop(N, body)` is intercepted by `MilkEvaluator::eval` before evalexpr
568    // sees it (see `evaluator::extract_top_level_loop`), so this stub is
569    // reached only for nested / non-top-level uses. Returning 0.0 keeps those
570    // equations parsing.
571    context
572        .set_function("loop".into(), Function::new(|_arg| Ok(Value::Float(0.0))))
573        .ok();
574
575    // EEL2 `while(cond)`, `exec2(a, b)`, `exec3(a, b, c)`. Like `loop`,
576    // they're intercepted before evalexpr sees them at the top level
577    // (`find_top_level_while_call` / `find_top_level_exec_call`). The stubs
578    // here only fire for nested / non-top-level uses where the args have
579    // already been collapsed to a single value by an outer interceptor.
580    // Returning 0.0 keeps those rare residuals parseable.
581    context
582        .set_function("while".into(), Function::new(|_arg| Ok(Value::Float(0.0))))
583        .ok();
584    context
585        .set_function("exec2".into(), Function::new(|_arg| Ok(Value::Float(0.0))))
586        .ok();
587    context
588        .set_function("exec3".into(), Function::new(|_arg| Ok(Value::Float(0.0))))
589        .ok();
590
591    // MD2 `gmegabuf` / `megabuf` are persistent N-slot scratch buffers
592    // (default 1 048 576 slots, indexed by f64 cast to usize). evalexpr
593    // `Function`s can't carry mutable state, so the backing lives in a
594    // `thread_local!` `RefCell<Vec<f64>>` cleared on every
595    // `MilkEvaluator::new()`. Reads out of range return 0.0; writes grow
596    // the buffer up to `gmegabuf::MAX_SLOTS`.
597    //
598    // Semantically, MD2 distinguishes `megabuf` (per-equation-block) from
599    // `gmegabuf` (preset-wide). Our backing is per-evaluator for both —
600    // close enough that the 163 init-loop presets in the corpus run
601    // their N-fold zero-fill without spilling state across presets, and
602    // any subsequent reads at the same indices see the values they wrote.
603    context
604        .set_function(
605            "gmegabuf".into(),
606            Function::new(|arg| {
607                let idx = arg.as_number().unwrap_or(0.0);
608                Ok(Value::Float(crate::math_functions::gmegabuf::read(idx)))
609            }),
610        )
611        .ok();
612
613    context
614        .set_function(
615            "megabuf".into(),
616            Function::new(|arg| {
617                let idx = arg.as_number().unwrap_or(0.0);
618                Ok(Value::Float(crate::math_functions::gmegabuf::read(idx)))
619            }),
620        )
621        .ok();
622
623    context
624        .set_function(
625            "gmegabuf_set".into(),
626            Function::new(|arg| {
627                if let Ok(tuple) = arg.as_tuple()
628                    && tuple.len() == 2
629                    && let (Ok(i), Ok(v)) = (tuple[0].as_number(), tuple[1].as_number())
630                {
631                    let i: f64 = i;
632                    let v: f64 = v;
633                    crate::math_functions::gmegabuf::write(i, v);
634                    return Ok(Value::Float(v));
635                }
636                Ok(Value::Float(0.0))
637            }),
638        )
639        .ok();
640
641    context
642        .set_function(
643            "megabuf_set".into(),
644            Function::new(|arg| {
645                if let Ok(tuple) = arg.as_tuple()
646                    && tuple.len() == 2
647                    && let (Ok(i), Ok(v)) = (tuple[0].as_number(), tuple[1].as_number())
648                {
649                    let i: f64 = i;
650                    let v: f64 = v;
651                    crate::math_functions::gmegabuf::write(i, v);
652                    return Ok(Value::Float(v));
653                }
654                Ok(Value::Float(0.0))
655            }),
656        )
657        .ok();
658
659    // MilkDrop-style if function (accepts Float condition)
660    context
661        .set_function(
662            "milkif".into(),
663            Function::new(|arg| {
664                if let Ok(tuple) = arg.as_tuple() {
665                    if tuple.len() == 3 {
666                        // Get condition (Float or Boolean)
667                        let condition = match &tuple[0] {
668                            Value::Float(f) => *f != 0.0,
669                            Value::Int(i) => *i != 0,
670                            Value::Boolean(b) => *b,
671                            _ => {
672                                return Err(evalexpr::EvalexprError::TypeError {
673                                    expected: vec![
674                                        evalexpr::ValueType::Float,
675                                        evalexpr::ValueType::Boolean,
676                                    ],
677                                    actual: tuple[0].clone(),
678                                });
679                            }
680                        };
681
682                        // Return true_val or false_val. MD2's `if()` is
683                        // numeric, so coerce Int and Boolean branches to
684                        // Float — otherwise a `milkif(cond, a>b, c>d)`
685                        // would leak a Boolean into the outer expression
686                        // and crash the next arithmetic op with
687                        // "Expected Float, got Boolean".
688                        let result = if condition { &tuple[1] } else { &tuple[2] };
689                        match result {
690                            Value::Int(i) => Ok(Value::Float(*i as f64)),
691                            Value::Boolean(b) => Ok(Value::Float(if *b { 1.0 } else { 0.0 })),
692                            other => Ok(other.clone()),
693                        }
694                    } else {
695                        Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
696                            expected: 3..=3,
697                            actual: tuple.len(),
698                        })
699                    }
700                } else {
701                    Err(evalexpr::EvalexprError::WrongFunctionArgumentAmount {
702                        expected: 3..=3,
703                        actual: 1,
704                    })
705                }
706            }),
707        )
708        .ok();
709}
710
/// Names of every function registered by [`register_math_functions`].
///
/// Besides the math builtins, this includes the control-flow entries
/// (`milkif` and the `loop` / `while` / `exec2` / `exec3` stubs) and the
/// scratch-buffer accessors, so the list matches what is actually registered.
/// The original 40 names keep their order; the remaining 8 are appended.
pub fn list_math_functions() -> Vec<&'static str> {
    vec![
        // Trigonometric
        "sin", "cos", "tan", "asin", "acos", "atan", "atan2",
        // Exponential and logarithmic
        "sqrt", "pow", "exp", "log", "ln", "log10",
        // Absolute value and sign
        "abs", "sign",
        // Rounding
        "fract", "trunc", "floor", "ceil", "round",
        // Modulo and clamping
        "fmod", "clamp", "min", "max",
        // Hyperbolic
        "sinh", "cosh", "tanh",
        // Additional
        "sqr", "rad", "deg", "sigmoid",
        // Random and comparison
        "rand", "above", "below", "equal",
        // Boolean
        "bnot", "band", "bor",
        // Type conversion
        "int",
        // Control flow
        "milkif", "loop", "while", "exec2", "exec3",
        // Scratch buffers
        "gmegabuf", "megabuf", "gmegabuf_set", "megabuf_set",
    ]
}
729
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use evalexpr::ContextWithMutableVariables;

    /// Fresh context with every MilkDrop builtin registered.
    fn context() -> HashMapContext<DefaultNumericTypes> {
        let mut context = HashMapContext::<DefaultNumericTypes>::new();
        register_math_functions(&mut context);
        context
    }

    /// Evaluate `expr` against a fresh context, naming the expression on failure.
    fn eval(expr: &str) -> f64 {
        evalexpr::eval_number_with_context(expr, &context())
            .unwrap_or_else(|err| panic!("`{expr}` failed: {err}"))
    }

    #[test]
    fn test_register_math_functions() {
        let context = context();

        // Test that basic functions work
        assert!(evalexpr::eval_number_with_context("sin(0)", &context).is_ok());
        assert!(evalexpr::eval_number_with_context("cos(0)", &context).is_ok());
        assert!(evalexpr::eval_number_with_context("sqrt(4)", &context).is_ok());
    }

    #[test]
    fn test_sin_function() {
        assert_relative_eq!(eval("sin(0)"), 0.0, epsilon = 1e-10);
        assert_relative_eq!(eval("sin(1.5707963267948966)"), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_cos_function() {
        assert_relative_eq!(eval("cos(0)"), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sqrt_function() {
        assert_relative_eq!(eval("sqrt(16)"), 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_abs_function() {
        assert_relative_eq!(eval("abs(-5)"), 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pow_function() {
        assert_relative_eq!(eval("pow(2, 3)"), 8.0, epsilon = 1e-10);
    }

    #[test]
    fn test_atan2_function() {
        assert_relative_eq!(eval("atan2(1, 1)"), std::f64::consts::FRAC_PI_4, epsilon = 1e-10);
    }

    #[test]
    fn test_sign_function() {
        assert_relative_eq!(eval("sign(-3)"), -1.0, epsilon = 1e-10);
        assert_relative_eq!(eval("sign(0)"), 0.0, epsilon = 1e-10);
        assert_relative_eq!(eval("sign(2.5)"), 1.0, epsilon = 1e-10);
    }

    /// The overrides exist because MD2 presets pass Int literals to
    /// builtins that evalexpr 13 treats as strict-Float.
    #[test]
    fn test_int_literals_are_coerced() {
        assert_relative_eq!(eval("max(0, 2)"), 2.0, epsilon = 1e-10);
        assert_relative_eq!(eval("min(3, 1.5)"), 1.5, epsilon = 1e-10);
        assert_relative_eq!(eval("floor(2)"), 2.0, epsilon = 1e-10);
        assert_relative_eq!(eval("clamp(5, 0, 1)"), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_boolean_helpers_accept_comparisons() {
        assert_relative_eq!(eval("band(2 > 1, 1)"), 1.0, epsilon = 1e-10);
        assert_relative_eq!(eval("bor(1 < 0, 0)"), 0.0, epsilon = 1e-10);
        assert_relative_eq!(eval("bnot(1 < 0)"), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_milkif_coerces_branches() {
        assert_relative_eq!(eval("milkif(0, 1, 2)"), 2.0, epsilon = 1e-10);
        assert_relative_eq!(eval("milkif(1, 2 > 1, 0)"), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_megabuf_round_trip() {
        gmegabuf::reset();
        assert_relative_eq!(eval("gmegabuf_set(5, 4)"), 4.0, epsilon = 1e-10);
        assert_relative_eq!(eval("megabuf(5)"), 4.0, epsilon = 1e-10);
        assert_relative_eq!(eval("gmegabuf(6)"), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rand_stays_in_range() {
        for _ in 0..100 {
            let r = eval("rand(10)");
            assert!((0.0..10.0).contains(&r), "rand(10) returned {r}");
        }
    }

    #[test]
    fn test_complex_expression() {
        let mut context = context();
        context.set_value("time".into(), Value::Float(1.0)).unwrap();

        let result = evalexpr::eval_number_with_context(
            "sin(time) * cos(time * 2) + sqrt(abs(time - 0.5))",
            &context,
        )
        .unwrap();

        let time = 1.0_f64;
        let expected = time.sin() * (time * 2.0).cos() + (time - 0.5).abs().sqrt();
        assert_relative_eq!(result, expected, epsilon = 1e-10);
    }
}