1use std::collections::HashMap;
25
/// The subset of WGSL value types this module tracks.
///
/// `Unknown` stands for anything that could not be classified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgslType {
    F32,
    I32,
    Vec2F,
    Vec3F,
    Vec4F,
    Mat2F,
    Mat3F,
    Mat4F,
    Bool,
    Unknown,
}

impl WgslType {
    /// True for the three `f32` vector types.
    pub fn is_vec(self) -> bool {
        [Self::Vec2F, Self::Vec3F, Self::Vec4F].contains(&self)
    }

    /// True for `F32`, `I32` and `Bool`.
    pub fn is_scalar(self) -> bool {
        [Self::F32, Self::I32, Self::Bool].contains(&self)
    }

    /// WGSL spelling of the type; `Unknown` renders as a comment placeholder.
    pub fn wgsl_name(self) -> &'static str {
        match self {
            Self::Unknown => "/* unknown */",
            Self::Bool => "bool",
            Self::I32 => "i32",
            Self::F32 => "f32",
            Self::Vec2F => "vec2<f32>",
            Self::Vec3F => "vec3<f32>",
            Self::Vec4F => "vec4<f32>",
            Self::Mat2F => "mat2x2<f32>",
            Self::Mat3F => "mat3x3<f32>",
            Self::Mat4F => "mat4x4<f32>",
        }
    }

    /// Parses the type text of a declaration (surrounding whitespace ignored).
    ///
    /// `u32` is folded into `I32`; any spelling not listed yields `Unknown`.
    fn from_decl_str(s: &str) -> Self {
        // Every type that can be recognised by its canonical spelling.
        const DECLARABLE: [WgslType; 9] = [
            WgslType::F32,
            WgslType::I32,
            WgslType::Vec2F,
            WgslType::Vec3F,
            WgslType::Vec4F,
            WgslType::Mat2F,
            WgslType::Mat3F,
            WgslType::Mat4F,
            WgslType::Bool,
        ];

        let spelled = s.trim();
        if spelled == "u32" {
            return Self::I32;
        }
        DECLARABLE
            .iter()
            .copied()
            .find(|candidate| candidate.wgsl_name() == spelled)
            .unwrap_or(Self::Unknown)
    }
}
88
/// Maps local variable names to the WGSL type recorded for them.
pub struct SymbolTable {
    /// Name → type entries; seeded and filled by `SymbolTable::from_source`.
    pub locals: HashMap<String, WgslType>,
}
96
impl SymbolTable {
    /// Builds a table from shader source text.
    ///
    /// The table starts with every entry of `WRAPPER_PRELUDE_LOCALS`, then
    /// the source is scanned for `var NAME: TYPE` / `let NAME: TYPE`
    /// declarations (text inside `//` and `/* */` comments is skipped).
    /// Declarations without a `: TYPE` annotation, or whose type text is not
    /// recognised by `WgslType::from_decl_str`, are not recorded.
    pub fn from_source(src: &str) -> Self {
        let mut locals = HashMap::new();

        for (n, t) in WRAPPER_PRELUDE_LOCALS {
            locals.insert((*n).to_string(), *t);
        }

        let bytes = src.as_bytes();
        let mut i = 0;
        while i < bytes.len() {
            // Skip a line comment up to (not including) the newline.
            if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'/' {
                while i < bytes.len() && bytes[i] != b'\n' {
                    i += 1;
                }
                continue;
            }
            // Skip a block comment; an unterminated one runs to end of input.
            if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' {
                i += 2;
                while i + 1 < bytes.len() && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
                    i += 1;
                }
                i += 2;
                continue;
            }

            let kw_len = match keyword_at(bytes, i, &["var", "let"]) {
                Some(n) => n,
                None => {
                    i += 1;
                    continue;
                }
            };
            // Whitespace, then the declared identifier.
            let mut j = i + kw_len;
            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
                j += 1;
            }
            let name_start = j;
            while j < bytes.len() && (bytes[j].is_ascii_alphanumeric() || bytes[j] == b'_') {
                j += 1;
            }
            if j == name_start {
                // Keyword not followed by an identifier: resume past this byte.
                i = j + 1;
                continue;
            }
            let name = &src[name_start..j];
            while j < bytes.len() && bytes[j].is_ascii_whitespace() {
                j += 1;
            }
            if j >= bytes.len() || bytes[j] != b':' {
                // No explicit type annotation: nothing to record.
                i = j + 1;
                continue;
            }
            j += 1;
            let ty_start = j;
            // The type text ends at the first `=` or `;` outside `<...>`.
            let mut depth_angle = 0i32;
            while j < bytes.len() {
                match bytes[j] {
                    b'<' => depth_angle += 1,
                    b'>' => depth_angle -= 1,
                    b'=' | b';' if depth_angle == 0 => break,
                    _ => {}
                }
                j += 1;
            }
            let ty = WgslType::from_decl_str(&src[ty_start..j]);
            if !matches!(ty, WgslType::Unknown) {
                use std::collections::hash_map::Entry;
                match locals.entry(name.to_string()) {
                    Entry::Vacant(e) => {
                        e.insert(ty);
                    }
                    Entry::Occupied(mut e) => {
                        // A redeclaration that changes scalar-ness or
                        // vec-ness demotes the name to Unknown; any other
                        // redeclaration replaces the stored type.
                        let prev = *e.get();
                        if prev != ty
                            && (prev.is_scalar() != ty.is_scalar() || prev.is_vec() != ty.is_vec())
                        {
                            e.insert(WgslType::Unknown);
                        } else {
                            e.insert(ty);
                        }
                    }
                }
            }
            i = j;
        }

        Self { locals }
    }

    /// Returns the recorded type for `name`, or `None` if it is not in the table.
    pub fn lookup(&self, name: &str) -> Option<WgslType> {
        self.locals.get(name).copied()
    }

    /// Best-effort type of an expression given as text.
    ///
    /// Comments and redundant outer parentheses are removed first. The
    /// checks below run in order and the first one that applies decides the
    /// result; `WgslType::Unknown` is returned when none applies.
    pub fn infer_expr_type(&self, expr: &str) -> WgslType {
        let stripped = strip_comments(expr);
        let expr = strip_outer_parens(stripped.trim());

        if expr.is_empty() {
            return WgslType::Unknown;
        }

        // Every numeric literal is reported as F32.
        if is_numeric_literal(expr) {
            return WgslType::F32;
        }

        if is_identifier(expr) {
            if let Some(t) = self.lookup(expr) {
                return t;
            }
            return WgslType::Unknown;
        }

        if let Some((head, args)) = split_call(expr) {
            if let Some(t) = constructor_type(head) {
                return t;
            }
            if let Some(t) = known_call_return_type(head) {
                return t;
            }
            if POLY_BUILTINS.contains(&head) {
                // Result follows the narrowest vector argument, or F32 when
                // no argument is inferred to be a vector.
                let mut smallest: Option<WgslType> = None;
                for a in split_top_level_commas(args) {
                    let t = self.infer_expr_type(a.trim());
                    if t.is_vec() {
                        smallest = Some(match smallest {
                            None => t,
                            Some(s) => narrower(s, t),
                        });
                    }
                }
                return smallest.unwrap_or(WgslType::F32);
            }
        }

        // Binary arithmetic: combine the operand types with `widen`. When
        // every operand is Unknown, fall through to the remaining checks.
        if let Some(operands) = split_binop_operands(expr) {
            let mut widest = WgslType::Unknown;
            for op in &operands {
                let t = self.infer_expr_type(op);
                widest = widen(widest, t);
            }
            if !matches!(widest, WgslType::Unknown) {
                return widest;
            }
        }

        // `prefix.swizzle` on a vector: the swizzle length picks the type.
        if let Some((prefix, comp)) = split_last_swizzle(expr)
            && is_swizzle_components(comp)
        {
            let prefix_ty = self.infer_expr_type(prefix);
            if prefix_ty.is_vec() {
                return swizzle_target_type(comp.len());
            }
        }

        // Leading unary `-`, `+` or `!`: use the type of the operand.
        if matches!(expr.as_bytes().first(), Some(b'-' | b'+' | b'!')) {
            return self.infer_expr_type(&expr[1..]);
        }

        WgslType::Unknown
    }
}
331
/// Names and types inserted into every table by `SymbolTable::from_source`
/// before the source itself is scanned.
// NOTE(review): presumably these mirror the variables declared by the shader
// wrapper prelude (per the constant's name) — confirm against that prelude.
const WRAPPER_PRELUDE_LOCALS: &[(&str, WgslType)] = &[
    ("uv", WgslType::Vec2F),
    ("uv_orig", WgslType::Vec2F),
    ("rad", WgslType::F32),
    ("ang", WgslType::F32),
    ("ret", WgslType::Vec3F),
    ("color", WgslType::Vec3F),
    ("texsize", WgslType::Vec4F),
    ("aspect", WgslType::Vec4F),
    ("time", WgslType::F32),
    ("fps", WgslType::F32),
    ("frame", WgslType::F32),
    ("progress", WgslType::F32),
    ("bass", WgslType::F32),
    ("mid", WgslType::F32),
    ("treb", WgslType::F32),
    ("vol", WgslType::F32),
    ("bass_att", WgslType::F32),
    ("mid_att", WgslType::F32),
    ("treb_att", WgslType::F32),
    ("vol_att", WgslType::F32),
    ("rand_preset", WgslType::Vec4F),
    ("rand_frame", WgslType::Vec4F),
    ("slow_roam_cos", WgslType::Vec4F),
    ("slow_roam_sin", WgslType::Vec4F),
    ("roam_cos", WgslType::Vec4F),
    ("roam_sin", WgslType::Vec4F),
    ("blur1_min", WgslType::F32),
    ("blur1_max", WgslType::F32),
    ("blur2_min", WgslType::F32),
    ("blur2_max", WgslType::F32),
    ("blur3_min", WgslType::F32),
    ("blur3_max", WgslType::F32),
    ("hue_shader", WgslType::Vec3F),
    ("g_fTexSize", WgslType::Vec4F),
    ("texsize_noise_lq", WgslType::Vec4F),
    ("texsize_noise_lq_lite", WgslType::Vec4F),
    ("texsize_noise_mq", WgslType::Vec4F),
    ("texsize_noise_hq", WgslType::Vec4F),
    ("texsize_noisevol_lq", WgslType::Vec4F),
    ("texsize_noisevol_hq", WgslType::Vec4F),
    ("q1", WgslType::F32),
    ("q2", WgslType::F32),
    ("q3", WgslType::F32),
    ("q4", WgslType::F32),
    ("q5", WgslType::F32),
    ("q6", WgslType::F32),
    ("q7", WgslType::F32),
    ("q8", WgslType::F32),
    ("q9", WgslType::F32),
    ("q10", WgslType::F32),
    ("q11", WgslType::F32),
    ("q12", WgslType::F32),
    ("q13", WgslType::F32),
    ("q14", WgslType::F32),
    ("q15", WgslType::F32),
    ("q16", WgslType::F32),
    ("q17", WgslType::F32),
    ("q18", WgslType::F32),
    ("q19", WgslType::F32),
    ("q20", WgslType::F32),
    ("q21", WgslType::F32),
    ("q22", WgslType::F32),
    ("q23", WgslType::F32),
    ("q24", WgslType::F32),
    ("q25", WgslType::F32),
    ("q26", WgslType::F32),
    ("q27", WgslType::F32),
    ("q28", WgslType::F32),
    ("q29", WgslType::F32),
    ("q30", WgslType::F32),
    ("q31", WgslType::F32),
    ("q32", WgslType::F32),
    ("M_PI", WgslType::F32),
    ("M_PI_2", WgslType::F32),
    ("M_INV_PI", WgslType::F32),
    ("M_INV_PI_2", WgslType::F32),
];
413
/// If one of `kws` starts at byte offset `i` as a whole word, returns its
/// length.
///
/// "Whole word" means the byte before `i` (if any) and the byte after the
/// keyword (if any) are not ASCII alphanumerics or `_`. Keywords are tried in
/// the order given and the first match wins.
fn keyword_at(bytes: &[u8], i: usize, kws: &[&str]) -> Option<usize> {
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // A preceding identifier byte means we are in the middle of a longer word.
    if i > 0 && is_ident_byte(bytes[i - 1]) {
        return None;
    }

    let rest = &bytes[i..];
    kws.iter().find_map(|kw| {
        let after = rest.strip_prefix(kw.as_bytes())?;
        match after.first() {
            Some(&next) if is_ident_byte(next) => None,
            _ => Some(kw.len()),
        }
    })
}
433
/// Replaces `//` and `/* */` comments with whitespace.
///
/// Every comment byte becomes a space, except that newlines inside block
/// comments are kept, so the result has the same byte length and line
/// structure as `src`. The newline that ends a line comment is kept too.
///
/// The scan for the closing `*/` starts after the opening `/*`, so `/*/` is
/// correctly seen as an opener followed by `/` rather than as a complete
/// comment. An unterminated block comment is blanked through end of input.
fn strip_comments(src: &str) -> String {
    let bytes = src.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match (bytes[i], bytes.get(i + 1)) {
            (b'/', Some(b'/')) => {
                while i < bytes.len() && bytes[i] != b'\n' {
                    out.push(b' ');
                    i += 1;
                }
            }
            (b'/', Some(b'*')) => {
                // Blank the opener first so its `*` cannot pair with a
                // following `/` to form a bogus terminator.
                out.extend_from_slice(b"  ");
                i += 2;
                loop {
                    match bytes.get(i) {
                        // Unterminated: everything up to here is blanked.
                        None => break,
                        Some(b'*') if bytes.get(i + 1) == Some(&b'/') => {
                            out.extend_from_slice(b"  ");
                            i += 2;
                            break;
                        }
                        Some(&b) => {
                            out.push(if b == b'\n' { b'\n' } else { b' ' });
                            i += 1;
                        }
                    }
                }
            }
            (b, _) => {
                out.push(b);
                i += 1;
            }
        }
    }
    // Bytes outside comments are copied verbatim and comment bytes become
    // ASCII, so the output is valid UTF-8 whenever the input was.
    String::from_utf8(out).expect("comment-stripping preserved UTF-8")
}
472
/// Trims whitespace and peels off parentheses that enclose the whole
/// expression, repeatedly.
///
/// A pair is only removed when the opening `(` at the start is closed by the
/// final `)`; `(a) + (b)` is therefore returned unchanged (apart from
/// trimming).
fn strip_outer_parens(expr: &str) -> &str {
    let mut current = expr.trim();
    while current.starts_with('(') && current.ends_with(')') {
        // Index of the `)` that brings the nesting level back to zero.
        let mut nesting = 0i32;
        let first_close = current.bytes().position(|b| {
            match b {
                b'(' => nesting += 1,
                b')' => nesting -= 1,
                _ => {}
            }
            b == b')' && nesting == 0
        });
        if first_close != Some(current.len() - 1) {
            break;
        }
        current = current[1..current.len() - 1].trim();
    }
    current
}
504
/// Loose check for a numeric literal.
///
/// After trimming, any trailing run of `f`/`i`/`u`/`h` suffix letters and one
/// leading `-` are dropped; what remains must be non-empty, consist only of
/// digits, `.`, `e`, `E`, `+`, `-`, and contain at least one digit.
fn is_numeric_literal(s: &str) -> bool {
    let unsuffixed = s
        .trim()
        .trim_end_matches(|c: char| matches!(c, 'f' | 'i' | 'u' | 'h'));
    let body = match unsuffixed.strip_prefix('-') {
        Some(rest) => rest,
        None => unsuffixed,
    };

    let mut saw_digit = false;
    for c in body.chars() {
        match c {
            '0'..='9' => saw_digit = true,
            '.' | 'e' | 'E' | '+' | '-' => {}
            _ => return false,
        }
    }
    saw_digit
}
520
/// True when `s` is a non-empty ASCII identifier: a letter or `_` followed by
/// letters, digits or `_`.
fn is_identifier(s: &str) -> bool {
    let Some((&head, tail)) = s.as_bytes().split_first() else {
        return false;
    };
    let head_ok = head.is_ascii_alphabetic() || head == b'_';
    head_ok && tail.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'_')
}
529
/// Splits `head(args)` into `(head, args)`.
///
/// Succeeds only when the first `(` is not at the start of `expr` and its
/// matching `)` is the final byte. The head is trimmed; the argument text is
/// returned as written.
fn split_call(expr: &str) -> Option<(&str, &str)> {
    let open = expr.find('(').filter(|&at| at > 0)?;
    if !expr.ends_with(')') {
        return None;
    }

    // Offset of the `)` that balances the first `(`.
    let mut nesting = 0i32;
    let close = open
        + expr.as_bytes()[open..].iter().position(|&b| {
            match b {
                b'(' => nesting += 1,
                b')' => nesting -= 1,
                _ => {}
            }
            b == b')' && nesting == 0
        })?;

    if close == expr.len() - 1 {
        Some((expr[..open].trim(), &expr[open + 1..close]))
    } else {
        None
    }
}
562
563fn constructor_type(head: &str) -> Option<WgslType> {
564 Some(match head {
565 "vec2<f32>" => WgslType::Vec2F,
566 "vec3<f32>" => WgslType::Vec3F,
567 "vec4<f32>" => WgslType::Vec4F,
568 "mat2x2<f32>" => WgslType::Mat2F,
569 "mat3x3<f32>" => WgslType::Mat3F,
570 "mat4x4<f32>" => WgslType::Mat4F,
571 "f32" => WgslType::F32,
572 "i32" => WgslType::I32,
573 _ => return None,
574 })
575}
576
577fn known_call_return_type(head: &str) -> Option<WgslType> {
578 Some(match head {
579 "GetPixel" | "GetBlur1" | "GetBlur2" | "GetBlur3" => WgslType::Vec3F,
581 "lum" => WgslType::F32,
582 "length" | "distance" | "dot" => WgslType::F32,
584 "cross" => WgslType::Vec3F,
585 "textureSample" => WgslType::Vec4F,
586 _ => return None,
595 })
596}
597
598fn split_last_swizzle(expr: &str) -> Option<(&str, &str)> {
601 let bytes = expr.as_bytes();
602 let mut depth_paren = 0i32;
603 let mut depth_angle = 0i32;
604 let mut last_dot = None;
605 for (i, &b) in bytes.iter().enumerate() {
606 match b {
607 b'(' => depth_paren += 1,
608 b')' => depth_paren -= 1,
609 b'<' => depth_angle += 1,
610 b'>' => depth_angle -= 1,
611 b'.' if depth_paren == 0 && depth_angle == 0 => last_dot = Some(i),
612 _ => {}
613 }
614 }
615 let dot = last_dot?;
616 let pre = expr[..dot].trim();
618 if pre.is_empty() || is_numeric_literal(pre) {
619 return None;
620 }
621 Some((pre, &expr[dot + 1..]))
622}
623
/// True when `s` is one to four letters drawn from `xyzw` / `rgba`.
fn is_swizzle_components(s: &str) -> bool {
    (1..=4).contains(&s.len()) && s.bytes().all(|b| b"xyzwrgba".contains(&b))
}
630
631pub(crate) fn vec_size(t: WgslType) -> usize {
635 match t {
636 WgslType::Vec2F => 2,
637 WgslType::Vec3F => 3,
638 WgslType::Vec4F => 4,
639 _ => 0,
640 }
641}
642
643pub(crate) fn vec_of_size(n: usize) -> WgslType {
648 match n {
649 1 => WgslType::F32,
650 2 => WgslType::Vec2F,
651 3 => WgslType::Vec3F,
652 4 => WgslType::Vec4F,
653 _ => WgslType::Unknown,
654 }
655}
656
657fn swizzle_target_type(len: usize) -> WgslType {
658 match len {
659 1 => WgslType::F32,
660 2 => WgslType::Vec2F,
661 3 => WgslType::Vec3F,
662 4 => WgslType::Vec4F,
663 _ => WgslType::Unknown,
664 }
665}
666
/// Splits an expression at its top-level binary `+ - * /` operators.
///
/// An operator character counts as binary only when it appears outside
/// `()`, `<>` and `[]` and does not directly follow another operator or the
/// start of the expression (whitespace ignored), which keeps unary signs
/// attached to their operand. Returns the operand slices in order, untrimmed,
/// or `None` when there is no such operator.
fn split_binop_operands(expr: &str) -> Option<Vec<&str>> {
    let mut parens = 0i32;
    let mut angles = 0i32;
    let mut brackets = 0i32;
    // True at the start and right after a binary operator: the next
    // operator-looking byte would be a unary sign, not a split point.
    let mut expecting_operand = true;
    let mut cut_points: Vec<usize> = Vec::new();

    for (pos, byte) in expr.bytes().enumerate() {
        // Whitespace never changes the state.
        if byte.is_ascii_whitespace() {
            continue;
        }

        let at_top_level = parens == 0 && angles == 0 && brackets == 0;
        let is_operator = matches!(byte, b'+' | b'-' | b'*' | b'/');
        if is_operator && at_top_level && !expecting_operand {
            cut_points.push(pos);
            expecting_operand = true;
            continue;
        }

        match byte {
            b'(' => parens += 1,
            b')' => parens -= 1,
            b'<' => angles += 1,
            b'>' => angles -= 1,
            b'[' => brackets += 1,
            b']' => brackets -= 1,
            _ => {}
        }
        expecting_operand = false;
    }

    if cut_points.is_empty() {
        return None;
    }

    let mut pieces = Vec::with_capacity(cut_points.len() + 1);
    let mut from = 0;
    for cut in cut_points {
        pieces.push(&expr[from..cut]);
        from = cut + 1;
    }
    pieces.push(&expr[from..]);
    Some(pieces)
}
728
729fn widen(a: WgslType, b: WgslType) -> WgslType {
732 match (a, b) {
733 (WgslType::Unknown, x) | (x, WgslType::Unknown) => x,
734 (x, y) if x == y => x,
735 (WgslType::Vec4F, _) | (_, WgslType::Vec4F) => WgslType::Vec4F,
736 (WgslType::Vec3F, _) | (_, WgslType::Vec3F) => WgslType::Vec3F,
737 (WgslType::Vec2F, _) | (_, WgslType::Vec2F) => WgslType::Vec2F,
738 _ => a,
739 }
740}
741
/// Call names whose arguments `inject_broadcasts` inspects and rewrites.
const BROADCAST_BUILTINS: &[&str] = &[
    "clamp",
    "min",
    "max",
    "mix",
    "step",
    "smoothstep",
    "pow",
    "dot",
    "cross",
];
768
/// Call names whose result type `SymbolTable::infer_expr_type` derives from
/// the arguments: the narrowest vector argument, or `F32` when none is a
/// vector.
const POLY_BUILTINS: &[&str] = &[
    "clamp",
    "min",
    "max",
    "mix",
    "step",
    "smoothstep",
    "pow",
    "abs",
    "sign",
    "floor",
    "ceil",
    "fract",
    "exp",
    "log",
    "sin",
    "cos",
    "tan",
    "sqrt",
    "normalize",
];
795
796pub fn inject_broadcasts(src: &str, table: &SymbolTable) -> String {
806 let bytes = src.as_bytes();
807 let mut out = String::with_capacity(src.len() + 64);
808 let mut i = 0;
809
810 while i < bytes.len() {
811 let mut matched = None;
813 for name in BROADCAST_BUILTINS {
814 let len = name.len();
815 if i + len < bytes.len()
816 && &bytes[i..i + len] == name.as_bytes()
817 && bytes[i + len] == b'('
818 && (i == 0 || !(bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_'))
819 {
820 matched = Some(*name);
821 break;
822 }
823 }
824
825 let Some(name) = matched else {
826 out.push(bytes[i] as char);
827 i += 1;
828 continue;
829 };
830
831 let arg_start = i + name.len() + 1;
833 let mut j = arg_start;
834 let mut depth = 1i32;
835 while j < bytes.len() {
836 match bytes[j] {
837 b'(' => depth += 1,
838 b')' => {
839 depth -= 1;
840 if depth == 0 {
841 break;
842 }
843 }
844 _ => {}
845 }
846 j += 1;
847 }
848 if j >= bytes.len() {
849 out.push_str(&src[i..]);
851 return out;
852 }
853
854 let args_text = &src[arg_start..j];
855 let raw_args = split_top_level_commas(args_text);
856
857 let rewritten_args: Vec<String> = raw_args
862 .iter()
863 .map(|a| inject_broadcasts(a.trim(), table))
864 .collect();
865 let arg_types: Vec<WgslType> = rewritten_args
866 .iter()
867 .map(|a| table.infer_expr_type(a))
868 .collect();
869
870 let target = arg_types
876 .iter()
877 .copied()
878 .filter(|t| t.is_vec())
879 .reduce(narrower);
880
881 out.push_str(name);
882 out.push('(');
883 match target {
884 Some(vec_ty) => {
885 for (k, (arg, ty)) in rewritten_args.iter().zip(arg_types.iter()).enumerate() {
886 if k > 0 {
887 out.push_str(", ");
888 }
889 let wrap = arg_wrap(*ty, vec_ty, arg);
890 match wrap {
891 ArgWrap::Broadcast => {
892 out.push_str(vec_ty.wgsl_name());
893 out.push('(');
894 out.push_str(arg);
895 out.push(')');
896 }
897 ArgWrap::Truncate(swizzle) => {
898 out.push('(');
899 out.push_str(arg);
900 out.push(')');
901 out.push_str(swizzle);
902 }
903 ArgWrap::None => {
904 out.push_str(arg);
905 }
906 }
907 }
908 }
909 None => {
910 for (k, arg) in rewritten_args.iter().enumerate() {
912 if k > 0 {
913 out.push_str(", ");
914 }
915 out.push_str(arg);
916 }
917 }
918 }
919 out.push(')');
920 i = j + 1;
921 }
922
923 out
924}
925
/// How `inject_broadcasts` emits one call argument, as decided by `arg_wrap`.
enum ArgWrap {
    /// Emit the argument unchanged.
    None,
    /// Wrap the argument in the target vector type's constructor.
    Broadcast,
    /// Parenthesise the argument and append this swizzle suffix.
    Truncate(&'static str),
}
931
932fn arg_wrap(arg_ty: WgslType, target: WgslType, arg: &str) -> ArgWrap {
938 if arg_ty == target {
939 return ArgWrap::None;
940 }
941 if arg_ty.is_scalar() {
942 return ArgWrap::Broadcast;
943 }
944 if matches!(arg_ty, WgslType::Unknown) && is_numeric_literal(arg.trim()) {
945 return ArgWrap::Broadcast;
946 }
947 if arg_ty.is_vec() && target.is_vec() && vec_size(arg_ty) > vec_size(target) {
948 return ArgWrap::Truncate(match vec_size(target) {
949 2 => ".xy",
950 3 => ".xyz",
951 _ => "",
952 });
953 }
954 ArgWrap::None
955}
956
957fn narrower(a: WgslType, b: WgslType) -> WgslType {
961 match (a, b) {
962 (WgslType::Unknown, x) | (x, WgslType::Unknown) => x,
963 (x, y) if x == y => x,
964 (WgslType::Vec2F, _) | (_, WgslType::Vec2F) => WgslType::Vec2F,
965 (WgslType::Vec3F, _) | (_, WgslType::Vec3F) => WgslType::Vec3F,
966 (WgslType::Vec4F, _) | (_, WgslType::Vec4F) => WgslType::Vec4F,
967 _ => a,
968 }
969}
970
/// Splits `s` at commas that are outside `(...)` and `<...>`.
///
/// Pieces are returned untrimmed. There is always at least one piece: an
/// empty input yields a single empty slice.
fn split_top_level_commas(s: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let (mut parens, mut angles) = (0i32, 0i32);
    let mut piece_start = 0usize;

    for (idx, ch) in s.char_indices() {
        if ch == '(' {
            parens += 1;
        } else if ch == ')' {
            parens -= 1;
        } else if ch == '<' {
            angles += 1;
        } else if ch == '>' {
            angles -= 1;
        } else if ch == ',' && parens == 0 && angles == 0 {
            pieces.push(&s[piece_start..idx]);
            piece_start = idx + 1;
        }
    }

    pieces.push(&s[piece_start..]);
    pieces
}
993
994pub fn inject_assignment_coercions(src: &str, table: &SymbolTable) -> String {
1018 let bytes = src.as_bytes();
1019 let mut out = String::with_capacity(src.len() + 64);
1020 let mut i = 0usize;
1021 let mut at_stmt_start = true;
1022 let mut paren = 0i32;
1023
1024 while i < bytes.len() {
1025 if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'/' {
1027 while i < bytes.len() && bytes[i] != b'\n' {
1028 out.push(bytes[i] as char);
1029 i += 1;
1030 }
1031 continue;
1032 }
1033 if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' {
1034 let s = i;
1035 i += 2;
1036 while i + 1 < bytes.len() && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
1037 i += 1;
1038 }
1039 if i + 1 < bytes.len() {
1040 i += 2;
1041 }
1042 out.push_str(&src[s..i]);
1043 continue;
1044 }
1045
1046 match bytes[i] {
1048 b'(' => paren += 1,
1049 b')' => paren -= 1,
1050 _ => {}
1051 }
1052
1053 if !at_stmt_start || paren != 0 || !bytes[i].is_ascii_alphabetic() && bytes[i] != b'_' {
1054 if !bytes[i].is_ascii_whitespace() {
1058 at_stmt_start = matches!(bytes[i], b';' | b'{' | b'}');
1059 }
1060 out.push(bytes[i] as char);
1061 i += 1;
1062 continue;
1063 }
1064
1065 let id_start = i;
1067 let mut p = i;
1068 while p < bytes.len() && (bytes[p].is_ascii_alphanumeric() || bytes[p] == b'_') {
1069 p += 1;
1070 }
1071 let id_end = p;
1072 let name = &src[id_start..id_end];
1073 let mut sw_len = 0usize;
1081 if p < bytes.len() && bytes[p] == b'.' && p + 1 < bytes.len() {
1082 let sw_start = p + 1;
1083 let mut q = sw_start;
1084 while q < bytes.len()
1085 && matches!(
1086 bytes[q],
1087 b'x' | b'y' | b'z' | b'w' | b'r' | b'g' | b'b' | b'a'
1088 )
1089 {
1090 q += 1;
1091 }
1092 if q > sw_start && q - sw_start <= 4 {
1093 sw_len = q - sw_start;
1094 p = q;
1095 }
1096 }
1097 while p < bytes.len() && bytes[p].is_ascii_whitespace() {
1099 p += 1;
1100 }
1101 if p >= bytes.len() {
1102 out.push_str(&src[id_start..p]);
1103 i = p;
1104 at_stmt_start = false;
1105 continue;
1106 }
1107 let (op_byte, op_len) = match bytes[p] {
1109 b'=' if bytes.get(p + 1) != Some(&b'=') => (None, 1),
1110 b'+' if bytes.get(p + 1) == Some(&b'=') => (Some(b'+'), 2),
1111 b'-' if bytes.get(p + 1) == Some(&b'=') => (Some(b'-'), 2),
1112 b'*' if bytes.get(p + 1) == Some(&b'=') => (Some(b'*'), 2),
1113 b'/' if bytes.get(p + 1) == Some(&b'=') => (Some(b'/'), 2),
1114 _ => {
1115 out.push_str(&src[id_start..p]);
1117 i = p;
1118 at_stmt_start = false;
1119 continue;
1120 }
1121 };
1122 let op_end = p + op_len;
1123 let rhs_start = op_end;
1125 let mut q = rhs_start;
1126 let mut dpar = 0i32;
1127 let mut dbr = 0i32;
1128 while q < bytes.len() {
1129 match bytes[q] {
1130 b'(' => dpar += 1,
1131 b')' => dpar -= 1,
1132 b'[' => dbr += 1,
1133 b']' => dbr -= 1,
1134 b';' if dpar == 0 && dbr == 0 => break,
1135 _ => {}
1136 }
1137 q += 1;
1138 }
1139 if q >= bytes.len() {
1140 out.push_str(&src[id_start..q]);
1141 i = q;
1142 continue;
1143 }
1144 let rhs_raw = &src[rhs_start..q];
1145 let rhs_trimmed = rhs_raw.trim();
1146
1147 let Some(base_ty) = table.lookup(name) else {
1149 out.push_str(&src[id_start..q]);
1150 i = q;
1151 at_stmt_start = false;
1152 continue;
1153 };
1154
1155 let lhs_ty = if sw_len == 0 {
1160 base_ty
1161 } else if base_ty.is_vec() {
1162 match sw_len {
1163 1 => WgslType::F32,
1164 2 => WgslType::Vec2F,
1165 3 => WgslType::Vec3F,
1166 4 => WgslType::Vec4F,
1167 _ => {
1168 out.push_str(&src[id_start..q]);
1169 i = q;
1170 at_stmt_start = false;
1171 continue;
1172 }
1173 }
1174 } else {
1175 out.push_str(&src[id_start..q]);
1176 i = q;
1177 at_stmt_start = false;
1178 continue;
1179 };
1180
1181 if !lhs_ty.is_scalar() && !lhs_ty.is_vec() {
1184 out.push_str(&src[id_start..q]);
1185 i = q;
1186 at_stmt_start = false;
1187 continue;
1188 }
1189
1190 let rhs_ty = table.infer_expr_type(rhs_trimmed);
1191
1192 out.push_str(&src[id_start..op_end]);
1194 out.push(' ');
1195
1196 let coerced = match (lhs_ty, rhs_ty) {
1205 (l, WgslType::F32) | (l, WgslType::I32)
1209 if l.is_vec() && matches!(op_byte, None | Some(b'+') | Some(b'-')) =>
1210 {
1211 Some(format!("{}({})", l.wgsl_name(), rhs_trimmed))
1212 }
1213 (l, r) if l.is_vec() && r.is_vec() && vec_size(r) > vec_size(l) => {
1215 let sw = match vec_size(l) {
1216 2 => ".xy",
1217 3 => ".xyz",
1218 _ => "",
1219 };
1220 Some(format!("({}){}", rhs_trimmed, sw))
1221 }
1222 (l, r) if l.is_scalar() && r.is_vec() => Some(format!("({}).x", rhs_trimmed)),
1226 _ => None,
1227 };
1228
1229 if let Some(new_rhs) = coerced {
1230 out.push_str(&new_rhs);
1231 } else {
1232 out.push_str(rhs_raw);
1233 }
1234 out.push(';');
1235 i = q + 1;
1236 at_stmt_start = true;
1237 }
1238
1239 out
1240}
1241
1242pub fn inject_truncations(src: &str, table: &SymbolTable) -> String {
1262 let bytes = src.as_bytes();
1263 let mut out = String::with_capacity(src.len() + 64);
1264 let mut i = 0;
1265
1266 while i < bytes.len() {
1267 if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'/' {
1269 while i < bytes.len() && bytes[i] != b'\n' {
1270 out.push(bytes[i] as char);
1271 i += 1;
1272 }
1273 continue;
1274 }
1275
1276 let kw_len = match keyword_at(bytes, i, &["var", "let"]) {
1277 Some(n) => n,
1278 None => {
1279 out.push(bytes[i] as char);
1280 i += 1;
1281 continue;
1282 }
1283 };
1284
1285 let kw_end = i + kw_len;
1288 let mut p = kw_end;
1289 while p < bytes.len() && bytes[p].is_ascii_whitespace() {
1290 p += 1;
1291 }
1292 let name_start = p;
1293 while p < bytes.len() && (bytes[p].is_ascii_alphanumeric() || bytes[p] == b'_') {
1294 p += 1;
1295 }
1296 if p == name_start {
1297 out.push_str(&src[i..kw_end]);
1298 i = kw_end;
1299 continue;
1300 }
1301 let after_name = p;
1302 while p < bytes.len() && bytes[p].is_ascii_whitespace() {
1303 p += 1;
1304 }
1305 if p >= bytes.len() || bytes[p] != b':' {
1306 out.push_str(&src[i..after_name]);
1307 i = after_name;
1308 continue;
1309 }
1310 p += 1;
1311 let ty_start = p;
1313 while p < bytes.len() && bytes[p] != b'=' && bytes[p] != b';' {
1314 p += 1;
1315 }
1316 let ty_str = src[ty_start..p].trim();
1317 let lhs_ty = WgslType::from_decl_str(ty_str);
1318 if !matches!(
1320 lhs_ty,
1321 WgslType::F32 | WgslType::I32 | WgslType::Vec2F | WgslType::Vec3F | WgslType::Vec4F
1322 ) {
1323 out.push_str(&src[i..p]);
1324 i = p;
1325 continue;
1326 }
1327 if p >= bytes.len() || bytes[p] != b'=' {
1328 out.push_str(&src[i..p]);
1330 i = p;
1331 continue;
1332 }
1333 let eq_at = p;
1334 p += 1;
1335 let rhs_start = p;
1337 let mut depth_paren = 0i32;
1338 let mut depth_bracket = 0i32;
1339 while p < bytes.len() {
1340 match bytes[p] {
1341 b'(' => depth_paren += 1,
1342 b')' => depth_paren -= 1,
1343 b'[' => depth_bracket += 1,
1344 b']' => depth_bracket -= 1,
1345 b';' if depth_paren == 0 && depth_bracket == 0 => break,
1346 _ => {}
1347 }
1348 p += 1;
1349 }
1350 if p >= bytes.len() {
1351 out.push_str(&src[i..]);
1353 return out;
1354 }
1355 let rhs = &src[rhs_start..p];
1356 let rhs_trimmed = rhs.trim();
1357 let rhs_ty = table.infer_expr_type(rhs_trimmed);
1358
1359 out.push_str(&src[i..=eq_at]);
1361
1362 match (lhs_ty, rhs_ty) {
1363 (l, r) if l.is_scalar() && r.is_vec() => {
1365 out.push_str(" (");
1366 out.push_str(rhs_trimmed);
1367 out.push_str(").x");
1368 }
1369 (l, WgslType::F32) | (l, WgslType::I32) if l.is_vec() => {
1371 out.push(' ');
1372 out.push_str(l.wgsl_name());
1373 out.push('(');
1374 out.push_str(rhs_trimmed);
1375 out.push(')');
1376 }
1377 (l, r) if l.is_vec() && r.is_vec() && vec_size(r) > vec_size(l) => {
1379 let swizzle = match vec_size(l) {
1380 2 => ".xy",
1381 3 => ".xyz",
1382 _ => "",
1383 };
1384 out.push_str(" (");
1385 out.push_str(rhs_trimmed);
1386 out.push(')');
1387 out.push_str(swizzle);
1388 }
1389 _ => {
1390 out.push_str(rhs);
1391 }
1392 }
1393 out.push(';');
1394 i = p + 1;
1395 }
1396
1397 out
1398}
1399
1400pub fn inject_swizzle_assignments(src: &str, table: &SymbolTable) -> String {
1428 use regex::Regex;
1429 use std::sync::LazyLock;
1430
1431 static SWZ_RE: LazyLock<Regex> = LazyLock::new(|| {
1432 Regex::new(
1436 r"(?m)^([\t ]*)([A-Za-z_][A-Za-z0-9_]*)\.([xyzwrgba]{2,4})\s*([+\-*/]?=)\s*([^;]+);",
1437 )
1438 .unwrap()
1439 });
1440
1441 SWZ_RE
1442 .replace_all(src, |caps: ®ex::Captures| {
1443 let indent = &caps[1];
1444 let target = &caps[2];
1445 let swizzle = &caps[3];
1446 let op = &caps[4];
1447 let rhs = caps[5].trim();
1448
1449 let swizzle_xyzw = normalise_swizzle(swizzle);
1452
1453 if !all_unique(&swizzle_xyzw) {
1456 return caps[0].to_string();
1457 }
1458
1459 let target_ty = match table.lookup(target) {
1461 Some(WgslType::Vec3F) => WgslType::Vec3F,
1462 Some(WgslType::Vec4F) => WgslType::Vec4F,
1463 _ => return caps[0].to_string(),
1464 };
1465 let target_size = vec_size(target_ty);
1466
1467 if swizzle_xyzw.len() > target_size {
1469 return caps[0].to_string();
1470 }
1471
1472 let comps = match target_size {
1474 3 => &['x', 'y', 'z'][..],
1475 4 => &['x', 'y', 'z', 'w'][..],
1476 _ => return caps[0].to_string(),
1477 };
1478 let rhs_is_scalar = table.infer_expr_type(rhs).is_scalar();
1479 let mut lane_exprs: Vec<String> = Vec::with_capacity(target_size);
1480 for &c in comps {
1481 if let Some(pos) = swizzle_xyzw.iter().position(|&s| s == c) {
1482 let lane_letter = match pos {
1487 0 => 'x',
1488 1 => 'y',
1489 2 => 'z',
1490 3 => 'w',
1491 _ => return caps[0].to_string(),
1492 };
1493 let rhs_lane = if rhs_is_scalar || swizzle_xyzw.len() == 1 {
1494 format!("({rhs})")
1496 } else {
1497 format!("({rhs}).{lane_letter}")
1498 };
1499 let new_val = match op {
1500 "=" => rhs_lane,
1501 "+=" => format!("{target}.{c} + {rhs_lane}"),
1502 "-=" => format!("{target}.{c} - {rhs_lane}"),
1503 "*=" => format!("{target}.{c} * {rhs_lane}"),
1504 "/=" => format!("{target}.{c} / {rhs_lane}"),
1505 _ => return caps[0].to_string(),
1506 };
1507 lane_exprs.push(new_val);
1508 } else {
1509 lane_exprs.push(format!("{target}.{c}"));
1511 }
1512 }
1513
1514 format!(
1515 "{indent}{target} = {ty}({args});",
1516 ty = target_ty.wgsl_name(),
1517 args = lane_exprs.join(", ")
1518 )
1519 })
1520 .to_string()
1521}
1522
/// Converts swizzle letters to their `xyzw` form: `r`→`x`, `g`→`y`, `b`→`z`,
/// `a`→`w`. Any other character is passed through unchanged.
fn normalise_swizzle(s: &str) -> Vec<char> {
    const XYZW: [char; 4] = ['x', 'y', 'z', 'w'];

    let mut lanes = Vec::new();
    for c in s.chars() {
        let mapped = match "rgba".find(c) {
            Some(slot) => XYZW[slot],
            None => c,
        };
        lanes.push(mapped);
    }
    lanes
}
1538
/// True when every element is one of `x`, `y`, `z`, `w` and none repeats.
/// An empty slice counts as unique.
fn all_unique(letters: &[char]) -> bool {
    // One bit per lane records which components were already seen.
    let mut used = 0u8;
    letters.iter().all(|&c| {
        let Some(lane) = "xyzw".find(c) else {
            return false;
        };
        let bit = 1u8 << lane;
        let fresh = used & bit == 0;
        used |= bit;
        fresh
    })
}
1556
1557#[cfg(test)]
1558mod tests {
1559 use super::*;
1560
1561 #[test]
1562 fn symbol_table_picks_up_var_decls() {
1563 let src = "var foo: vec3<f32> = vec3<f32>(0.0); let bar: f32 = 1.0;";
1564 let t = SymbolTable::from_source(src);
1565 assert_eq!(t.lookup("foo"), Some(WgslType::Vec3F));
1566 assert_eq!(t.lookup("bar"), Some(WgslType::F32));
1567 }
1568
1569 #[test]
1570 fn symbol_table_seeds_wrapper_locals() {
1571 let t = SymbolTable::from_source("// nothing here");
1574 assert_eq!(t.lookup("uv"), Some(WgslType::Vec2F));
1575 assert_eq!(t.lookup("texsize"), Some(WgslType::Vec4F));
1576 assert_eq!(t.lookup("q1"), Some(WgslType::F32));
1577 assert_eq!(t.lookup("M_PI_2"), Some(WgslType::F32));
1578 }
1579
1580 #[test]
1581 fn infer_expr_numeric_literal() {
1582 let t = SymbolTable::from_source("");
1583 assert_eq!(t.infer_expr_type("1.0"), WgslType::F32);
1584 assert_eq!(t.infer_expr_type("-3.14"), WgslType::F32);
1585 assert_eq!(t.infer_expr_type("0"), WgslType::F32);
1586 }
1587
1588 #[test]
1589 fn infer_expr_known_helpers() {
1590 let t = SymbolTable::from_source("");
1591 assert_eq!(t.infer_expr_type("GetPixel(uv)"), WgslType::Vec3F);
1592 assert_eq!(t.infer_expr_type("GetBlur1(uv)"), WgslType::Vec3F);
1593 assert_eq!(t.infer_expr_type("lum(ret)"), WgslType::F32);
1594 assert_eq!(
1595 t.infer_expr_type("textureSample(sampler_main_texture, sampler_main, uv)"),
1596 WgslType::Vec4F
1597 );
1598 }
1599
1600 #[test]
1601 fn infer_expr_constructor() {
1602 let t = SymbolTable::from_source("");
1603 assert_eq!(t.infer_expr_type("vec3<f32>(1, 0, 0)"), WgslType::Vec3F);
1604 assert_eq!(t.infer_expr_type("vec4<f32>(c)"), WgslType::Vec4F);
1605 }
1606
1607 #[test]
1608 fn infer_expr_swizzle_narrows_to_scalar() {
1609 let t = SymbolTable::from_source("var c: vec4<f32> = vec4<f32>(1);");
1610 assert_eq!(t.infer_expr_type("c.x"), WgslType::F32);
1611 assert_eq!(t.infer_expr_type("c.xy"), WgslType::Vec2F);
1612 assert_eq!(t.infer_expr_type("c.xyz"), WgslType::Vec3F);
1613 }
1614
1615 #[test]
1616 fn infer_expr_binop_widens_to_vec() {
1617 let t = SymbolTable::from_source("var c: vec4<f32> = vec4<f32>(1);");
1618 assert_eq!(t.infer_expr_type("GetPixel(uv) * c.x"), WgslType::Vec3F);
1620 }
1621
1622 #[test]
1623 fn broadcast_clamp_with_scalar_bounds() {
1624 let t = SymbolTable::from_source("");
1625 let src = "ret = clamp(GetBlur1(uv), 0.0, 1.0);";
1626 let out = inject_broadcasts(src, &t);
1627 assert!(
1628 out.contains("clamp(GetBlur1(uv), vec3<f32>(0.0), vec3<f32>(1.0))"),
1629 "got: {out}"
1630 );
1631 }
1632
1633 #[test]
1634 fn broadcast_pow_with_scalar_exponent() {
1635 let t = SymbolTable::from_source("");
1636 let src = "ret = pow(GetPixel(uv), 0.5);";
1637 let out = inject_broadcasts(src, &t);
1638 assert!(
1639 out.contains("pow(GetPixel(uv), vec3<f32>(0.5))"),
1640 "got: {out}"
1641 );
1642 }
1643
1644 #[test]
1645 fn broadcast_mix_with_scalar_lerp_factor() {
1646 let t = SymbolTable::from_source("");
1647 let src = "ret = mix(a, b, 0.3);";
1648 let out = inject_broadcasts(src, &t);
1651 assert_eq!(out, src);
1652
1653 let t = SymbolTable::from_source(
1655 "var a: vec3<f32> = vec3<f32>(0); var b: vec3<f32> = vec3<f32>(1);",
1656 );
1657 let out = inject_broadcasts(src, &t);
1658 assert!(out.contains("mix(a, b, vec3<f32>(0.3))"), "got: {out}");
1659 }
1660
1661 #[test]
1662 fn broadcast_truncates_larger_vec_to_smaller() {
1663 let t = SymbolTable::from_source("var ret: vec3<f32> = vec3<f32>(0);");
1666 let src = "ret = max(ret, textureSample(t, s, uv));";
1667 let out = inject_broadcasts(src, &t);
1668 assert!(
1669 out.contains("max(ret, (textureSample(t, s, uv)).xyz)"),
1670 "got: {out}"
1671 );
1672 }
1673
1674 #[test]
1675 fn broadcast_skipped_when_all_args_scalar() {
1676 let t = SymbolTable::from_source("");
1677 let src = "var x: f32 = clamp(0.5, 0.0, 1.0);";
1678 let out = inject_broadcasts(src, &t);
1679 assert_eq!(out, src);
1680 }
1681
1682 #[test]
1683 fn truncation_f32_eq_vec3_inserts_dot_x() {
1684 let t = SymbolTable::from_source("");
1685 let src = "var gx1: f32 = GetPixel(uv) + GetBlur1(uv);";
1686 let out = inject_truncations(src, &t);
1687 assert!(
1688 out.contains("var gx1: f32 = (GetPixel(uv) + GetBlur1(uv)).x;"),
1689 "got: {out}"
1690 );
1691 }
1692
1693 #[test]
1694 fn truncation_skipped_when_rhs_already_scalar() {
1695 let t = SymbolTable::from_source("");
1696 let src = "var x: f32 = 1.0 + 2.0;";
1697 let out = inject_truncations(src, &t);
1698 assert_eq!(out, src);
1699 }
1700
1701 #[test]
1702 fn truncation_skipped_when_lhs_is_vec() {
1703 let t = SymbolTable::from_source("");
1704 let src = "var v: vec3<f32> = GetPixel(uv);";
1705 let out = inject_truncations(src, &t);
1706 assert_eq!(out, src);
1707 }
1708
1709 #[test]
1710 fn broadcast_vec3_eq_scalar_wraps_in_constructor() {
1711 let t = SymbolTable::from_source("var dz: vec3<f32> = vec3<f32>(0);");
1714 let src = "var bg: vec3<f32> = pow(length(dz), 0.7)*2 + GetBlur1(uv).y*2;";
1715 let out = inject_truncations(src, &t);
1716 assert!(
1717 out.contains(
1718 "var bg: vec3<f32> = vec3<f32>(pow(length(dz), 0.7)*2 + GetBlur1(uv).y*2);"
1719 ),
1720 "got: {out}"
1721 );
1722 }
1723
1724 #[test]
1725 fn truncation_vec3_eq_vec4_appends_xyz() {
1726 let t = SymbolTable::from_source("");
1729 let src = "var ret2: vec3<f32> = textureSample(t, s, uv);";
1730 let out = inject_truncations(src, &t);
1731 assert!(
1732 out.contains("var ret2: vec3<f32> = (textureSample(t, s, uv)).xyz;"),
1733 "got: {out}"
1734 );
1735 }
1736
1737 #[test]
1738 fn truncation_vec2_eq_vec4_appends_xy() {
1739 let t = SymbolTable::from_source("");
1740 let src = "var sam: vec2<f32> = textureSample(t, s, uv);";
1741 let out = inject_truncations(src, &t);
1742 assert!(
1743 out.contains("var sam: vec2<f32> = (textureSample(t, s, uv)).xy;"),
1744 "got: {out}"
1745 );
1746 }
1747
1748 #[test]
1749 fn broadcast_skipped_when_rhs_already_vec() {
1750 let t = SymbolTable::from_source("");
1751 let src = "var v: vec3<f32> = GetPixel(uv);";
1752 let out = inject_truncations(src, &t);
1753 assert_eq!(out, src);
1754 }
1755
1756 #[test]
1757 fn truncation_skipped_on_unknown_rhs_type() {
1758 let t = SymbolTable::from_source("");
1759 let src = "var x: f32 = some_user_function(uv);";
1762 let out = inject_truncations(src, &t);
1763 assert_eq!(out, src);
1764 }
1765
1766 #[test]
1771 fn swizzle_xy_assignment_on_vec3_target() {
1772 let t = SymbolTable::from_source("");
1776 let src = "ret.xy = diff;";
1777 let out = inject_swizzle_assignments(src, &t);
1778 assert_eq!(
1779 out, "ret = vec3<f32>((diff).x, (diff).y, ret.z);",
1780 "got: {out}"
1781 );
1782 }
1783
1784 #[test]
1785 fn swizzle_xyz_full_replace_on_vec3_target() {
1786 let t = SymbolTable::from_source("");
1787 let src = "ret.xyz = tex2D(s, uv).xyz;";
1788 let out = inject_swizzle_assignments(src, &t);
1789 assert_eq!(
1792 out,
1793 "ret = vec3<f32>((tex2D(s, uv).xyz).x, (tex2D(s, uv).xyz).y, (tex2D(s, uv).xyz).z);",
1794 );
1795 }
1796
1797 #[test]
1798 fn swizzle_compound_mul_assign() {
1799 let t = SymbolTable::from_source("");
1800 let src = "ret.xy *= diff;";
1801 let out = inject_swizzle_assignments(src, &t);
1802 assert_eq!(
1804 out,
1805 "ret = vec3<f32>(ret.x * (diff).x, ret.y * (diff).y, ret.z);",
1806 );
1807 }
1808
1809 #[test]
1810 fn swizzle_reordered_zy() {
1811 let t = SymbolTable::from_source("");
1813 let src = "ret.zy = pair;";
1814 let out = inject_swizzle_assignments(src, &t);
1815 assert_eq!(out, "ret = vec3<f32>(ret.x, (pair).y, (pair).x);",);
1817 }
1818
1819 #[test]
1820 fn swizzle_rgba_normalised_to_xyzw() {
1821 let t = SymbolTable::from_source("");
1822 let src = "ret.rg = pair;";
1823 let out = inject_swizzle_assignments(src, &t);
1824 assert_eq!(out, "ret = vec3<f32>((pair).x, (pair).y, ret.z);",);
1826 }
1827
1828 #[test]
1829 fn swizzle_skipped_when_target_unknown() {
1830 let t = SymbolTable::from_source("");
1832 let src = "user_var.xy = stuff;";
1833 let out = inject_swizzle_assignments(src, &t);
1834 assert_eq!(out, src);
1835 }
1836
1837 #[test]
1838 fn swizzle_skipped_on_single_component() {
1839 let t = SymbolTable::from_source("");
1842 let src = "ret.z = stuff;";
1843 let out = inject_swizzle_assignments(src, &t);
1844 assert_eq!(out, src);
1845 }
1846
1847 #[test]
1848 fn swizzle_skipped_when_duplicate_components() {
1849 let t = SymbolTable::from_source("");
1851 let src = "ret.xx = pair;";
1852 let out = inject_swizzle_assignments(src, &t);
1853 assert_eq!(out, src);
1854 }
1855
1856 #[test]
1857 fn swizzle_div_compound_on_vec3() {
1858 let t = SymbolTable::from_source("");
1859 let src = "ret.zy /= diff2;";
1860 let out = inject_swizzle_assignments(src, &t);
1861 assert_eq!(
1863 out,
1864 "ret = vec3<f32>(ret.x, ret.y / (diff2).y, ret.z / (diff2).x);",
1865 );
1866 }
1867
1868 #[test]
1869 fn swizzle_xy_on_vec4_target() {
1870 let t = SymbolTable::from_source("var ret4: vec4<f32> = vec4<f32>(0.0);");
1872 let src = "ret4.xy = pair;";
1873 let out = inject_swizzle_assignments(src, &t);
1874 assert_eq!(out, "ret4 = vec4<f32>((pair).x, (pair).y, ret4.z, ret4.w);",);
1875 }
1876}