1use std::fmt;
19
/// A region of source text.
///
/// `start..end` are byte offsets into the source (end exclusive). `line` and `col`
/// locate the first byte; for spans produced by `tokenize` both are 1-based, and
/// `col` counts bytes from the start of the line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Span {
    pub start: u32,
    pub end: u32,
    pub line: u32,
    pub col: u32,
}

impl Span {
    /// An all-zero span.
    ///
    /// Lexer-produced spans always have `line >= 1`, so this value is
    /// distinguishable from any real location (presumably used for synthesized
    /// nodes — TODO confirm against callers).
    pub fn dummy() -> Self {
        Span {
            start: 0,
            end: 0,
            line: 0,
            col: 0,
        }
    }

    /// Returns the smallest span that covers both `self` and `other`.
    ///
    /// The column comes from whichever span starts first (`self` wins a tie);
    /// the line is the smaller of the two lines.
    pub fn merge(self, other: Span) -> Span {
        let leading = if other.start < self.start { other } else { self };
        Span {
            start: leading.start,
            end: self.end.max(other.end),
            line: self.line.min(other.line),
            col: leading.col,
        }
    }
}
57
58#[derive(Debug, Clone, PartialEq)]
63pub enum TokenKind {
64 IntLit(i64),
70 FloatLit(f64),
73 BoolLit(bool),
75
76 Ident,
81 Keyword(Keyword),
85
86 LParen, RParen, LBrace, RBrace, LBracket, RBracket, Semi, Comma, Dot, Question, Colon, Plus,
101 Minus,
102 Star,
103 Slash,
104 Percent,
105 Eq,
106 EqEq,
107 BangEq,
108 Lt,
109 LtEq,
110 Gt,
111 GtEq,
112 PlusEq,
113 MinusEq,
114 StarEq,
115 SlashEq,
116 PercentEq,
117 Amp,
118 AmpAmp,
119 Bar,
120 BarBar,
121 Caret,
122 Tilde,
123 Bang,
124 AmpEq,
125 BarEq,
126 CaretEq,
127 Shl,
128 Shr,
129 ShlEq,
130 ShrEq,
131 PlusPlus,
132 MinusMinus,
133
134 Eof,
137}
138
/// Reserved words recognised by the lexer.
///
/// `Sampler` is produced for each of the spellings `sampler`, `sampler2D` and
/// `sampler3D`; it displays as `sampler`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Keyword {
    If,
    Else,
    While,
    For,
    Do,
    Return,
    Break,
    Continue,
    Static,
    Const,
    In,
    Out,
    InOut,
    Struct,
    Sampler,
    Void,
    ShaderBody,
}

impl fmt::Display for Keyword {
    /// Writes the keyword's canonical source spelling.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            Self::If => "if",
            Self::Else => "else",
            Self::While => "while",
            Self::For => "for",
            Self::Do => "do",
            Self::Return => "return",
            Self::Break => "break",
            Self::Continue => "continue",
            Self::Static => "static",
            Self::Const => "const",
            Self::In => "in",
            Self::Out => "out",
            Self::InOut => "inout",
            Self::Struct => "struct",
            Self::Sampler => "sampler",
            Self::Void => "void",
            Self::ShaderBody => "shader_body",
        })
    }
}
190
191#[derive(Debug, Clone, PartialEq)]
194pub struct Token {
195 pub kind: TokenKind,
196 pub span: Span,
197}
198
/// An error produced while tokenizing.
///
/// `line` and `col` point at where the offending construct begins.
#[derive(Debug, Clone, PartialEq)]
pub struct LexError {
    /// Human-readable description of the problem.
    pub message: String,
    /// 1-based line number.
    pub line: u32,
    /// 1-based column, counted in bytes from the start of the line.
    pub col: u32,
    /// Byte offset into the source associated with the error.
    pub offset: u32,
}

impl fmt::Display for LexError {
    /// Renders as `lex error at LINE:COL: MESSAGE`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LexError {
            message, line, col, ..
        } = self;
        write!(f, "lex error at {}:{}: {}", line, col, message)
    }
}

impl std::error::Error for LexError {}
220
221pub fn tokenize(src: &str) -> Result<Vec<Token>, LexError> {
228 let bytes = src.as_bytes();
229 let mut tokens = Vec::with_capacity(bytes.len() / 4);
230 let mut i = 0usize;
231 let mut line = 1u32;
232 let mut line_start = 0usize;
233
234 while i < bytes.len() {
235 let b = bytes[i];
236
237 if b == b'\n' {
239 line += 1;
240 line_start = i + 1;
241 i += 1;
242 continue;
243 }
244 if b.is_ascii_whitespace() {
245 i += 1;
246 continue;
247 }
248
249 if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
251 while i < bytes.len() && bytes[i] != b'\n' {
252 i += 1;
253 }
254 continue;
255 }
256
257 if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
259 let start_line = line;
260 let start_col = (i - line_start + 1) as u32;
261 i += 2;
262 loop {
263 if i + 1 >= bytes.len() {
264 return Err(LexError {
265 message: "unterminated block comment".to_string(),
266 line: start_line,
267 col: start_col,
268 offset: i as u32,
269 });
270 }
271 if bytes[i] == b'*' && bytes[i + 1] == b'/' {
272 i += 2;
273 break;
274 }
275 if bytes[i] == b'\n' {
276 line += 1;
277 line_start = i + 1;
278 }
279 i += 1;
280 }
281 continue;
282 }
283
284 if b == b'#' {
289 while i < bytes.len() && bytes[i] != b'\n' {
290 i += 1;
291 }
292 continue;
293 }
294
295 if b.is_ascii_digit() || (b == b'.' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit())
299 {
300 let start = i;
301 let col = (i - line_start + 1) as u32;
302 let (kind, end) = lex_number(bytes, i).map_err(|msg| LexError {
303 message: msg,
304 line,
305 col,
306 offset: start as u32,
307 })?;
308 tokens.push(Token {
309 kind,
310 span: Span {
311 start: start as u32,
312 end: end as u32,
313 line,
314 col,
315 },
316 });
317 i = end;
318 continue;
319 }
320
321 if b.is_ascii_alphabetic() || b == b'_' {
323 let start = i;
324 let col = (i - line_start + 1) as u32;
325 while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
326 i += 1;
327 }
328 let lex = &src[start..i];
329 let kind = match classify_ident(lex) {
330 Some(k) => k,
331 None => TokenKind::Ident,
332 };
333 tokens.push(Token {
334 kind,
335 span: Span {
336 start: start as u32,
337 end: i as u32,
338 line,
339 col,
340 },
341 });
342 continue;
343 }
344
345 let start = i;
347 let col = (i - line_start + 1) as u32;
348 let two = if i + 1 < bytes.len() {
349 Some(bytes[i + 1])
350 } else {
351 None
352 };
353 let three = if i + 2 < bytes.len() {
354 Some(bytes[i + 2])
355 } else {
356 None
357 };
358 let (kind, consumed) = match (b, two, three) {
359 (b'<', Some(b'<'), Some(b'=')) => (TokenKind::ShlEq, 3),
360 (b'>', Some(b'>'), Some(b'=')) => (TokenKind::ShrEq, 3),
361 (b'<', Some(b'<'), _) => (TokenKind::Shl, 2),
362 (b'>', Some(b'>'), _) => (TokenKind::Shr, 2),
363 (b'=', Some(b'='), _) => (TokenKind::EqEq, 2),
364 (b'!', Some(b'='), _) => (TokenKind::BangEq, 2),
365 (b'<', Some(b'='), _) => (TokenKind::LtEq, 2),
366 (b'>', Some(b'='), _) => (TokenKind::GtEq, 2),
367 (b'+', Some(b'+'), _) => (TokenKind::PlusPlus, 2),
368 (b'-', Some(b'-'), _) => (TokenKind::MinusMinus, 2),
369 (b'&', Some(b'&'), _) => (TokenKind::AmpAmp, 2),
370 (b'|', Some(b'|'), _) => (TokenKind::BarBar, 2),
371 (b'+', Some(b'='), _) => (TokenKind::PlusEq, 2),
372 (b'-', Some(b'='), _) => (TokenKind::MinusEq, 2),
373 (b'*', Some(b'='), _) => (TokenKind::StarEq, 2),
374 (b'/', Some(b'='), _) => (TokenKind::SlashEq, 2),
375 (b'%', Some(b'='), _) => (TokenKind::PercentEq, 2),
376 (b'&', Some(b'='), _) => (TokenKind::AmpEq, 2),
377 (b'|', Some(b'='), _) => (TokenKind::BarEq, 2),
378 (b'^', Some(b'='), _) => (TokenKind::CaretEq, 2),
379 (b'(', _, _) => (TokenKind::LParen, 1),
380 (b')', _, _) => (TokenKind::RParen, 1),
381 (b'{', _, _) => (TokenKind::LBrace, 1),
382 (b'}', _, _) => (TokenKind::RBrace, 1),
383 (b'[', _, _) => (TokenKind::LBracket, 1),
384 (b']', _, _) => (TokenKind::RBracket, 1),
385 (b';', _, _) => (TokenKind::Semi, 1),
386 (b',', _, _) => (TokenKind::Comma, 1),
387 (b'.', _, _) => (TokenKind::Dot, 1),
388 (b'?', _, _) => (TokenKind::Question, 1),
389 (b':', _, _) => (TokenKind::Colon, 1),
390 (b'+', _, _) => (TokenKind::Plus, 1),
391 (b'-', _, _) => (TokenKind::Minus, 1),
392 (b'*', _, _) => (TokenKind::Star, 1),
393 (b'/', _, _) => (TokenKind::Slash, 1),
394 (b'%', _, _) => (TokenKind::Percent, 1),
395 (b'=', _, _) => (TokenKind::Eq, 1),
396 (b'<', _, _) => (TokenKind::Lt, 1),
397 (b'>', _, _) => (TokenKind::Gt, 1),
398 (b'&', _, _) => (TokenKind::Amp, 1),
399 (b'|', _, _) => (TokenKind::Bar, 1),
400 (b'^', _, _) => (TokenKind::Caret, 1),
401 (b'~', _, _) => (TokenKind::Tilde, 1),
402 (b'!', _, _) => (TokenKind::Bang, 1),
403 _ => {
404 return Err(LexError {
405 message: format!("unexpected character '{}' (0x{:02x})", b as char, b),
406 line,
407 col,
408 offset: start as u32,
409 });
410 }
411 };
412 i += consumed;
413 tokens.push(Token {
414 kind,
415 span: Span {
416 start: start as u32,
417 end: i as u32,
418 line,
419 col,
420 },
421 });
422 }
423
424 tokens.push(Token {
425 kind: TokenKind::Eof,
426 span: Span {
427 start: bytes.len() as u32,
428 end: bytes.len() as u32,
429 line,
430 col: (bytes.len() - line_start + 1) as u32,
431 },
432 });
433
434 Ok(tokens)
435}
436
437fn classify_ident(s: &str) -> Option<TokenKind> {
443 let kw = match s {
444 "if" => Keyword::If,
445 "else" => Keyword::Else,
446 "while" => Keyword::While,
447 "for" => Keyword::For,
448 "do" => Keyword::Do,
449 "return" => Keyword::Return,
450 "break" => Keyword::Break,
451 "continue" => Keyword::Continue,
452 "static" => Keyword::Static,
453 "const" => Keyword::Const,
454 "in" => Keyword::In,
455 "out" => Keyword::Out,
456 "inout" => Keyword::InOut,
457 "struct" => Keyword::Struct,
458 "sampler" | "sampler2D" | "sampler3D" => Keyword::Sampler,
459 "void" => Keyword::Void,
460 "shader_body" => Keyword::ShaderBody,
461 "true" => return Some(TokenKind::BoolLit(true)),
462 "false" => return Some(TokenKind::BoolLit(false)),
463 _ => return None,
464 };
465 Some(TokenKind::Keyword(kw))
466}
467
468fn lex_number(bytes: &[u8], start: usize) -> Result<(TokenKind, usize), String> {
476 let mut i = start;
477 let mut saw_dot = false;
478 let mut saw_exp = false;
479
480 if bytes[i] == b'0' && i + 1 < bytes.len() && (bytes[i + 1] == b'x' || bytes[i + 1] == b'X') {
482 i += 2;
483 let hex_start = i;
484 while i < bytes.len() && bytes[i].is_ascii_hexdigit() {
485 i += 1;
486 }
487 if i == hex_start {
488 return Err("expected hex digits after `0x`".to_string());
489 }
490 let lex = std::str::from_utf8(&bytes[hex_start..i]).map_err(|e| e.to_string())?;
491 let v = i64::from_str_radix(lex, 16).map_err(|e| e.to_string())?;
492 while i < bytes.len() && matches!(bytes[i], b'u' | b'U' | b'l' | b'L') {
494 i += 1;
495 }
496 return Ok((TokenKind::IntLit(v), i));
497 }
498
499 while i < bytes.len() && bytes[i].is_ascii_digit() {
501 i += 1;
502 }
503 if i < bytes.len() && bytes[i] == b'.' {
505 saw_dot = true;
509 i += 1;
510 while i < bytes.len() && bytes[i].is_ascii_digit() {
511 i += 1;
512 }
513 }
514 if i < bytes.len() && (bytes[i] == b'e' || bytes[i] == b'E') {
516 saw_exp = true;
517 i += 1;
518 if i < bytes.len() && (bytes[i] == b'+' || bytes[i] == b'-') {
519 i += 1;
520 }
521 let exp_start = i;
522 while i < bytes.len() && bytes[i].is_ascii_digit() {
523 i += 1;
524 }
525 if i == exp_start {
526 return Err("expected digits in exponent".to_string());
527 }
528 }
529 let is_float_suffix = i < bytes.len() && matches!(bytes[i], b'f' | b'F' | b'h' | b'H');
531 let is_int_suffix = i < bytes.len() && matches!(bytes[i], b'u' | b'U' | b'l' | b'L');
532 if is_float_suffix {
533 i += 1;
534 saw_dot = true;
537 } else if is_int_suffix {
538 i += 1;
539 }
540
541 let raw = std::str::from_utf8(&bytes[start..i]).map_err(|e| e.to_string())?;
542 let stripped = if is_float_suffix || is_int_suffix {
544 &raw[..raw.len() - 1]
545 } else {
546 raw
547 };
548
549 if saw_dot || saw_exp {
550 let v: f64 = stripped
551 .parse()
552 .map_err(|e: std::num::ParseFloatError| e.to_string())?;
553 Ok((TokenKind::FloatLit(v), i))
554 } else {
555 let v: i64 = stripped
558 .parse()
559 .map_err(|e: std::num::ParseIntError| e.to_string())?;
560 Ok((TokenKind::IntLit(v), i))
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `src` (which must be valid) and returns only the token kinds.
    fn kinds(src: &str) -> Vec<TokenKind> {
        tokenize(src).unwrap().into_iter().map(|t| t.kind).collect()
    }

    /// Returns the slice of `src` covered by `tok`'s span.
    fn text<'a>(src: &'a str, tok: &Token) -> &'a str {
        &src[tok.span.start as usize..tok.span.end as usize]
    }

    #[test]
    fn empty_input_yields_eof() {
        assert_eq!(kinds(""), vec![TokenKind::Eof]);
    }

    #[test]
    fn integer_literals() {
        assert_eq!(
            kinds("0 1 12 02 0xff 0x1F"),
            vec![
                TokenKind::IntLit(0),
                TokenKind::IntLit(1),
                TokenKind::IntLit(12),
                TokenKind::IntLit(2),
                TokenKind::IntLit(0xff),
                TokenKind::IntLit(0x1F),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn float_literals_with_suffixes() {
        let k = kinds("1.5 .5 5. 1e2 1.5e-3 1.5f 0.5h");
        assert_eq!(k.len(), 8);
        assert!(matches!(k[0], TokenKind::FloatLit(v) if (v - 1.5).abs() < 1e-9));
        assert!(matches!(k[1], TokenKind::FloatLit(v) if (v - 0.5).abs() < 1e-9));
        assert!(matches!(k[2], TokenKind::FloatLit(v) if (v - 5.0).abs() < 1e-9));
        assert!(matches!(k[3], TokenKind::FloatLit(v) if (v - 100.0).abs() < 1e-9));
        assert!(matches!(k[4], TokenKind::FloatLit(v) if (v - 0.0015).abs() < 1e-12));
        assert!(matches!(k[5], TokenKind::FloatLit(v) if (v - 1.5).abs() < 1e-9));
        assert!(matches!(k[6], TokenKind::FloatLit(v) if (v - 0.5).abs() < 1e-9));
        assert_eq!(k[7], TokenKind::Eof);
    }

    #[test]
    fn malformed_numbers_error() {
        let err = tokenize("1e+").unwrap_err();
        assert!(err.message.contains("expected digits in exponent"));
        let err = tokenize("0x").unwrap_err();
        assert!(err.message.contains("expected hex digits"));
    }

    #[test]
    fn identifier_and_keyword_split() {
        assert_eq!(
            kinds("if while float2 my_var"),
            vec![
                TokenKind::Keyword(Keyword::If),
                TokenKind::Keyword(Keyword::While),
                TokenKind::Ident,
                TokenKind::Ident,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn keyword_display_round_trips_through_lexer() {
        let spellings = [
            "if", "else", "while", "for", "do", "return", "break", "continue", "static",
            "const", "in", "out", "inout", "struct", "sampler", "void", "shader_body",
        ];
        for &src in spellings.iter() {
            match kinds(src).as_slice() {
                [TokenKind::Keyword(kw), TokenKind::Eof] => assert_eq!(kw.to_string(), src),
                other => panic!("{:?} lexed as {:?}", src, other),
            }
        }
    }

    #[test]
    fn bool_literals() {
        assert_eq!(
            kinds("true false"),
            vec![
                TokenKind::BoolLit(true),
                TokenKind::BoolLit(false),
                TokenKind::Eof
            ]
        );
    }

    #[test]
    fn line_and_block_comments_skipped() {
        let k = kinds("a // comment\nb /* multi\nline */ c");
        assert_eq!(k.len(), 4);
        assert_eq!(k[0], TokenKind::Ident);
        assert_eq!(k[1], TokenKind::Ident);
        assert_eq!(k[2], TokenKind::Ident);
        assert_eq!(k[3], TokenKind::Eof);
    }

    #[test]
    fn block_comment_newlines_advance_line() {
        let toks = tokenize("/*\n\n*/ x").unwrap();
        assert_eq!(toks.len(), 2);
        assert_eq!(toks[0].kind, TokenKind::Ident);
        assert_eq!(toks[0].span.line, 3);
        assert_eq!(toks[0].span.col, 4);
    }

    #[test]
    fn preprocessor_lines_skipped() {
        assert_eq!(
            kinds("#define X 1\nfoo"),
            vec![TokenKind::Ident, TokenKind::Eof]
        );
    }

    #[test]
    fn compound_operators() {
        assert_eq!(
            kinds("+= -= *= /= %= == != <= >= && || ++ -- << >>"),
            vec![
                TokenKind::PlusEq,
                TokenKind::MinusEq,
                TokenKind::StarEq,
                TokenKind::SlashEq,
                TokenKind::PercentEq,
                TokenKind::EqEq,
                TokenKind::BangEq,
                TokenKind::LtEq,
                TokenKind::GtEq,
                TokenKind::AmpAmp,
                TokenKind::BarBar,
                TokenKind::PlusPlus,
                TokenKind::MinusMinus,
                TokenKind::Shl,
                TokenKind::Shr,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn shift_and_bitwise_assign_operators() {
        assert_eq!(
            kinds("<<= >>= &= |= ^="),
            vec![
                TokenKind::ShlEq,
                TokenKind::ShrEq,
                TokenKind::AmpEq,
                TokenKind::BarEq,
                TokenKind::CaretEq,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn single_operators() {
        assert_eq!(
            kinds("+ - * / % = < > & | ^ ~ ! ? : ;"),
            vec![
                TokenKind::Plus,
                TokenKind::Minus,
                TokenKind::Star,
                TokenKind::Slash,
                TokenKind::Percent,
                TokenKind::Eq,
                TokenKind::Lt,
                TokenKind::Gt,
                TokenKind::Amp,
                TokenKind::Bar,
                TokenKind::Caret,
                TokenKind::Tilde,
                TokenKind::Bang,
                TokenKind::Question,
                TokenKind::Colon,
                TokenKind::Semi,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn punctuation() {
        assert_eq!(
            kinds("( ) { } [ ] , ."),
            vec![
                TokenKind::LParen,
                TokenKind::RParen,
                TokenKind::LBrace,
                TokenKind::RBrace,
                TokenKind::LBracket,
                TokenKind::RBracket,
                TokenKind::Comma,
                TokenKind::Dot,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn spans_track_line_and_column() {
        let src = "abc\n  xyz";
        let toks = tokenize(src).unwrap();
        assert_eq!(toks.len(), 3);
        assert_eq!(toks[0].span.line, 1);
        assert_eq!(toks[0].span.col, 1);
        assert_eq!(toks[1].span.line, 2);
        assert_eq!(toks[1].span.col, 3);
        assert_eq!(toks[2].kind, TokenKind::Eof);
        assert_eq!(text(src, &toks[0]), "abc");
        assert_eq!(text(src, &toks[1]), "xyz");
    }

    #[test]
    fn eof_span_points_past_last_byte() {
        let toks = tokenize("ab\ncd").unwrap();
        assert_eq!(toks.len(), 3);
        assert_eq!(toks[2].kind, TokenKind::Eof);
        assert_eq!(
            toks[2].span,
            Span {
                start: 5,
                end: 5,
                line: 2,
                col: 3,
            }
        );
    }

    #[test]
    fn unterminated_block_comment_errors() {
        let err = tokenize("a /* never closed").unwrap_err();
        assert!(err.message.contains("unterminated block comment"));
        // Line and column point at the opening `/*`.
        assert_eq!((err.line, err.col), (1, 3));
    }

    #[test]
    fn unexpected_char_errors() {
        let err = tokenize("a @ b").unwrap_err();
        assert!(err.message.contains("unexpected character"));
        assert_eq!((err.line, err.col, err.offset), (1, 3, 2));
    }

    #[test]
    fn unexpected_char_location_on_later_line() {
        let err = tokenize("a\n  @").unwrap_err();
        assert!(err.message.contains("unexpected character"));
        assert_eq!((err.line, err.col, err.offset), (2, 3, 4));
    }

    #[test]
    fn md2_sample_warp_body_tokenises_cleanly() {
        let src = r#"
shader_body {
    float2 uv2 = uv - 0.5;
    uv2 *= aspect.xy;
    float dist = length(uv2);
    ret += tex2D(sampler_main, uv2) * 0.5;
}
"#;
        let toks = tokenize(src).unwrap();
        assert!(toks
            .iter()
            .any(|t| matches!(t.kind, TokenKind::Keyword(Keyword::ShaderBody))));
        let star_eqs = toks.iter().filter(|t| t.kind == TokenKind::StarEq).count();
        assert_eq!(star_eqs, 1);
        assert!(toks
            .iter()
            .any(|t| matches!(t.kind, TokenKind::FloatLit(v) if (v - 0.5).abs() < 1e-9)));
    }

    #[test]
    fn leading_zero_decimal_parses_as_value() {
        assert_eq!(kinds("02"), vec![TokenKind::IntLit(2), TokenKind::Eof]);
    }

    #[test]
    fn matrix_type_kept_as_identifier() {
        let src = "float2x2 m;";
        let toks = tokenize(src).unwrap();
        assert_eq!(toks[0].kind, TokenKind::Ident);
        assert_eq!(text(src, &toks[0]), "float2x2");
    }

    #[test]
    fn span_merge_combines_ranges() {
        let a = Span {
            start: 5,
            end: 8,
            line: 2,
            col: 3,
        };
        let b = Span {
            start: 12,
            end: 15,
            line: 2,
            col: 10,
        };
        let m = a.merge(b);
        assert_eq!(m.start, 5);
        assert_eq!(m.end, 15);
        assert_eq!(m.col, 3);
        // Merging is symmetric for non-overlapping spans.
        assert_eq!(b.merge(a), m);
    }
}