onedrop_hlsl/
lex.rs

1//! HLSL tokenizer.
2//!
//! Single-pass lexer over a `&str` of HLSL source. Produces a `Vec<Token>` with
4//! position tracking (byte offset + line + column). The tokenizer is
5//! intentionally permissive: HLSL semantic quirks (e.g. typed swizzle on RHS,
6//! `float2x2` matrix constructors, leading-zero integers, `f`/`h`/`u` numeric
7//! suffixes) survive as plain tokens, and the parser disambiguates them in
8//! context. Comments and preprocessor lines are skipped — MD2 user shaders
9//! don't depend on them. Continuation back-ticks (`` ` ``) that appear in
10//! `.milk` preset files are stripped one layer above (the `.milk` parser
11//! joins comp shader lines), so the lexer never sees them.
12//!
13//! Output is consumed by [`crate::parse`] and downstream emitters. The
14//! existing regex-based pipeline in [`crate::translate_shader`] does *not*
15//! call into the lexer; it stays in place and continues to drive the current
16//! pass-rate on `test-presets-200/`.
17
18use std::fmt;
19
/// Location of a lexeme in the source text.
///
/// `start..end` is a half-open byte range; `line` and `col` are the 1-indexed
/// position of `start`. Only the start position is stored because that is
/// all error reporting needs — the lexeme itself can be recovered by slicing
/// the source with `start..end`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Span {
    /// Byte offset of the first byte of the lexeme.
    pub start: u32,
    /// Byte offset one past the last byte of the lexeme.
    pub end: u32,
    /// 1-indexed line of `start` (0 only for [`Span::dummy`]).
    pub line: u32,
    /// 1-indexed column of `start` (0 only for [`Span::dummy`]).
    pub col: u32,
}

impl Span {
    /// Placeholder span with every field set to zero.
    pub fn dummy() -> Self {
        Span {
            start: 0,
            end: 0,
            line: 0,
            col: 0,
        }
    }

    /// Smallest span that covers both `self` and `other`.
    ///
    /// The byte range runs from the smaller `start` to the larger `end`, the
    /// line is the smaller of the two lines, and the column is taken from
    /// whichever span starts first (`self` wins a tie). Used when a
    /// syntactic construct spans several tokens, e.g. a binary expression
    /// covering its LHS through its RHS.
    pub fn merge(self, other: Span) -> Span {
        let col = if other.start < self.start {
            other.col
        } else {
            self.col
        };
        Span {
            start: u32::min(self.start, other.start),
            end: u32::max(self.end, other.end),
            line: u32::min(self.line, other.line),
            col,
        }
    }
}
57
58/// Token kinds produced by [`tokenize`]. Identifiers and keywords are
59/// distinguished here so the parser doesn't have to re-lookup the table for
60/// every ident. Literal values are pre-parsed into `i64` / `f64` so the
61/// parser doesn't have to re-parse the lexeme either.
62#[derive(Debug, Clone, PartialEq)]
63pub enum TokenKind {
64    // ---- Literals ----
65    /// Integer literal — value pre-parsed. HLSL accepts decimal (`12`,
66    /// leading-zero like `02`) and hex (`0x1F`); we widen to `i64`. The `u`
67    /// suffix (`12u`) is recognised but doesn't change the storage type
68    /// here — emitters decide.
69    IntLit(i64),
70    /// Float literal — value pre-parsed. HLSL accepts `1.5`, `.5`, `5.`,
71    /// `1e2`, `1.5e-3`, optional `f` / `h` suffix.
72    FloatLit(f64),
73    /// Boolean literal (`true` / `false`). HLSL uses lowercase only.
74    BoolLit(bool),
75
76    // ---- Identifiers and keywords ----
77    /// Generic identifier — type names like `float2`, `mat3x3`, sampler
78    /// names, user variables, builtins. The parser context tells us which
79    /// kind we have.
80    Ident,
81    /// `if`, `else`, `while`, `for`, `do`, `return`, `break`, `continue`,
82    /// `static`, `const`, `in`, `out`, `inout`, `struct`, `sampler`,
83    /// `void`. The variant carries the keyword identity.
84    Keyword(Keyword),
85
86    // ---- Punctuation ----
87    LParen,   // (
88    RParen,   // )
89    LBrace,   // {
90    RBrace,   // }
91    LBracket, // [
92    RBracket, // ]
93    Semi,     // ;
94    Comma,    // ,
95    Dot,      // .
96    Question, // ?
97    Colon,    // :
98
99    // ---- Operators ----
100    Plus,
101    Minus,
102    Star,
103    Slash,
104    Percent,
105    Eq,
106    EqEq,
107    BangEq,
108    Lt,
109    LtEq,
110    Gt,
111    GtEq,
112    PlusEq,
113    MinusEq,
114    StarEq,
115    SlashEq,
116    PercentEq,
117    Amp,
118    AmpAmp,
119    Bar,
120    BarBar,
121    Caret,
122    Tilde,
123    Bang,
124    AmpEq,
125    BarEq,
126    CaretEq,
127    Shl,
128    Shr,
129    ShlEq,
130    ShrEq,
131    PlusPlus,
132    MinusMinus,
133
134    /// End of input. Always the last token; the parser uses it to know it
135    /// has consumed everything.
136    Eof,
137}
138
/// Identity of a reserved word.
///
/// The lexer resolves reserved words while scanning identifiers, which lets
/// the parser `match` on this enum instead of comparing strings.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Keyword {
    If,
    Else,
    While,
    For,
    Do,
    Return,
    Break,
    Continue,
    Static,
    Const,
    In,
    Out,
    InOut,
    Struct,
    Sampler,
    Void,
    /// `shader_body` — MD2-specific wrapper. The translator strips it
    /// before parsing; the keyword exists so the lexer still handles
    /// inputs (e.g. in tests) where it has not been stripped.
    ShaderBody,
}

impl fmt::Display for Keyword {
    /// Writes the canonical source spelling of the keyword (`inout`,
    /// `shader_body`, …).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            Keyword::If => "if",
            Keyword::Else => "else",
            Keyword::While => "while",
            Keyword::For => "for",
            Keyword::Do => "do",
            Keyword::Return => "return",
            Keyword::Break => "break",
            Keyword::Continue => "continue",
            Keyword::Static => "static",
            Keyword::Const => "const",
            Keyword::In => "in",
            Keyword::Out => "out",
            Keyword::InOut => "inout",
            Keyword::Struct => "struct",
            Keyword::Sampler => "sampler",
            Keyword::Void => "void",
            Keyword::ShaderBody => "shader_body",
        })
    }
}
190
191/// One token: kind plus source span. The span is enough to recover the
192/// original lexeme via `&source[span.start as usize..span.end as usize]`.
193#[derive(Debug, Clone, PartialEq)]
194pub struct Token {
195    pub kind: TokenKind,
196    pub span: Span,
197}
198
/// Failure reported by the lexer.
///
/// The position fields are kept separate from the message so a caller (such
/// as the parser) can wrap the error with additional context of its own.
#[derive(Clone, Debug, PartialEq)]
pub struct LexError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// 1-indexed line the error is reported at.
    pub line: u32,
    /// 1-indexed column the error is reported at.
    pub col: u32,
    /// Byte offset into the source associated with the error.
    pub offset: u32,
}

impl fmt::Display for LexError {
    /// Formats as `lex error at <line>:<col>: <message>`; the byte offset is
    /// not part of the rendered text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LexError {
            message, line, col, ..
        } = self;
        write!(f, "lex error at {}:{}: {}", line, col, message)
    }
}

impl std::error::Error for LexError {}
220
221/// Tokenize a complete HLSL source string. Returns the full token stream
222/// terminated by [`TokenKind::Eof`], or the first error encountered.
223///
224/// The lexer never panics on user input — every fall-through case becomes
225/// either a real token or a [`LexError`]. Comments and `#`-prefixed
226/// preprocessor lines are skipped silently.
227pub fn tokenize(src: &str) -> Result<Vec<Token>, LexError> {
228    let bytes = src.as_bytes();
229    let mut tokens = Vec::with_capacity(bytes.len() / 4);
230    let mut i = 0usize;
231    let mut line = 1u32;
232    let mut line_start = 0usize;
233
234    while i < bytes.len() {
235        let b = bytes[i];
236
237        // ---- whitespace ----
238        if b == b'\n' {
239            line += 1;
240            line_start = i + 1;
241            i += 1;
242            continue;
243        }
244        if b.is_ascii_whitespace() {
245            i += 1;
246            continue;
247        }
248
249        // ---- line comment ----
250        if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
251            while i < bytes.len() && bytes[i] != b'\n' {
252                i += 1;
253            }
254            continue;
255        }
256
257        // ---- block comment ----
258        if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
259            let start_line = line;
260            let start_col = (i - line_start + 1) as u32;
261            i += 2;
262            loop {
263                if i + 1 >= bytes.len() {
264                    return Err(LexError {
265                        message: "unterminated block comment".to_string(),
266                        line: start_line,
267                        col: start_col,
268                        offset: i as u32,
269                    });
270                }
271                if bytes[i] == b'*' && bytes[i + 1] == b'/' {
272                    i += 2;
273                    break;
274                }
275                if bytes[i] == b'\n' {
276                    line += 1;
277                    line_start = i + 1;
278                }
279                i += 1;
280            }
281            continue;
282        }
283
284        // ---- preprocessor line ----
285        // `#define`, `#include`, `#pragma`, etc. MD2 user shaders rarely
286        // need them after the .milk parser is done; we drop them so the
287        // parser sees clean code.
288        if b == b'#' {
289            while i < bytes.len() && bytes[i] != b'\n' {
290                i += 1;
291            }
292            continue;
293        }
294
295        // ---- numeric literal ----
296        // Decide between int and float by scanning ahead for `.`, `e`/`E`,
297        // or a `f`/`h` suffix. Hex literals (`0x..`) are always int.
298        if b.is_ascii_digit() || (b == b'.' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit())
299        {
300            let start = i;
301            let col = (i - line_start + 1) as u32;
302            let (kind, end) = lex_number(bytes, i).map_err(|msg| LexError {
303                message: msg,
304                line,
305                col,
306                offset: start as u32,
307            })?;
308            tokens.push(Token {
309                kind,
310                span: Span {
311                    start: start as u32,
312                    end: end as u32,
313                    line,
314                    col,
315                },
316            });
317            i = end;
318            continue;
319        }
320
321        // ---- identifier / keyword ----
322        if b.is_ascii_alphabetic() || b == b'_' {
323            let start = i;
324            let col = (i - line_start + 1) as u32;
325            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
326                i += 1;
327            }
328            let lex = &src[start..i];
329            let kind = match classify_ident(lex) {
330                Some(k) => k,
331                None => TokenKind::Ident,
332            };
333            tokens.push(Token {
334                kind,
335                span: Span {
336                    start: start as u32,
337                    end: i as u32,
338                    line,
339                    col,
340                },
341            });
342            continue;
343        }
344
345        // ---- punctuation / operators ----
346        let start = i;
347        let col = (i - line_start + 1) as u32;
348        let two = if i + 1 < bytes.len() {
349            Some(bytes[i + 1])
350        } else {
351            None
352        };
353        let three = if i + 2 < bytes.len() {
354            Some(bytes[i + 2])
355        } else {
356            None
357        };
358        let (kind, consumed) = match (b, two, three) {
359            (b'<', Some(b'<'), Some(b'=')) => (TokenKind::ShlEq, 3),
360            (b'>', Some(b'>'), Some(b'=')) => (TokenKind::ShrEq, 3),
361            (b'<', Some(b'<'), _) => (TokenKind::Shl, 2),
362            (b'>', Some(b'>'), _) => (TokenKind::Shr, 2),
363            (b'=', Some(b'='), _) => (TokenKind::EqEq, 2),
364            (b'!', Some(b'='), _) => (TokenKind::BangEq, 2),
365            (b'<', Some(b'='), _) => (TokenKind::LtEq, 2),
366            (b'>', Some(b'='), _) => (TokenKind::GtEq, 2),
367            (b'+', Some(b'+'), _) => (TokenKind::PlusPlus, 2),
368            (b'-', Some(b'-'), _) => (TokenKind::MinusMinus, 2),
369            (b'&', Some(b'&'), _) => (TokenKind::AmpAmp, 2),
370            (b'|', Some(b'|'), _) => (TokenKind::BarBar, 2),
371            (b'+', Some(b'='), _) => (TokenKind::PlusEq, 2),
372            (b'-', Some(b'='), _) => (TokenKind::MinusEq, 2),
373            (b'*', Some(b'='), _) => (TokenKind::StarEq, 2),
374            (b'/', Some(b'='), _) => (TokenKind::SlashEq, 2),
375            (b'%', Some(b'='), _) => (TokenKind::PercentEq, 2),
376            (b'&', Some(b'='), _) => (TokenKind::AmpEq, 2),
377            (b'|', Some(b'='), _) => (TokenKind::BarEq, 2),
378            (b'^', Some(b'='), _) => (TokenKind::CaretEq, 2),
379            (b'(', _, _) => (TokenKind::LParen, 1),
380            (b')', _, _) => (TokenKind::RParen, 1),
381            (b'{', _, _) => (TokenKind::LBrace, 1),
382            (b'}', _, _) => (TokenKind::RBrace, 1),
383            (b'[', _, _) => (TokenKind::LBracket, 1),
384            (b']', _, _) => (TokenKind::RBracket, 1),
385            (b';', _, _) => (TokenKind::Semi, 1),
386            (b',', _, _) => (TokenKind::Comma, 1),
387            (b'.', _, _) => (TokenKind::Dot, 1),
388            (b'?', _, _) => (TokenKind::Question, 1),
389            (b':', _, _) => (TokenKind::Colon, 1),
390            (b'+', _, _) => (TokenKind::Plus, 1),
391            (b'-', _, _) => (TokenKind::Minus, 1),
392            (b'*', _, _) => (TokenKind::Star, 1),
393            (b'/', _, _) => (TokenKind::Slash, 1),
394            (b'%', _, _) => (TokenKind::Percent, 1),
395            (b'=', _, _) => (TokenKind::Eq, 1),
396            (b'<', _, _) => (TokenKind::Lt, 1),
397            (b'>', _, _) => (TokenKind::Gt, 1),
398            (b'&', _, _) => (TokenKind::Amp, 1),
399            (b'|', _, _) => (TokenKind::Bar, 1),
400            (b'^', _, _) => (TokenKind::Caret, 1),
401            (b'~', _, _) => (TokenKind::Tilde, 1),
402            (b'!', _, _) => (TokenKind::Bang, 1),
403            _ => {
404                return Err(LexError {
405                    message: format!("unexpected character '{}' (0x{:02x})", b as char, b),
406                    line,
407                    col,
408                    offset: start as u32,
409                });
410            }
411        };
412        i += consumed;
413        tokens.push(Token {
414            kind,
415            span: Span {
416                start: start as u32,
417                end: i as u32,
418                line,
419                col,
420            },
421        });
422    }
423
424    tokens.push(Token {
425        kind: TokenKind::Eof,
426        span: Span {
427            start: bytes.len() as u32,
428            end: bytes.len() as u32,
429            line,
430            col: (bytes.len() - line_start + 1) as u32,
431        },
432    });
433
434    Ok(tokens)
435}
436
437/// Recognise a reserved word. Returns `None` if the lexeme is a plain
438/// identifier. Type names (`float`, `float2`, `mat3x3`, …) are intentionally
439/// *not* classified as keywords — the parser treats them as identifiers in
440/// declaration contexts so vector/matrix variants can be added without
441/// touching the lexer.
442fn classify_ident(s: &str) -> Option<TokenKind> {
443    let kw = match s {
444        "if" => Keyword::If,
445        "else" => Keyword::Else,
446        "while" => Keyword::While,
447        "for" => Keyword::For,
448        "do" => Keyword::Do,
449        "return" => Keyword::Return,
450        "break" => Keyword::Break,
451        "continue" => Keyword::Continue,
452        "static" => Keyword::Static,
453        "const" => Keyword::Const,
454        "in" => Keyword::In,
455        "out" => Keyword::Out,
456        "inout" => Keyword::InOut,
457        "struct" => Keyword::Struct,
458        "sampler" | "sampler2D" | "sampler3D" => Keyword::Sampler,
459        "void" => Keyword::Void,
460        "shader_body" => Keyword::ShaderBody,
461        "true" => return Some(TokenKind::BoolLit(true)),
462        "false" => return Some(TokenKind::BoolLit(false)),
463        _ => return None,
464    };
465    Some(TokenKind::Keyword(kw))
466}
467
468/// Lex one numeric literal starting at `i`. Returns the token kind plus the
469/// end byte offset. Accepts:
470/// - decimal int (`0`, `12`, `02`)
471/// - hex int (`0x1F`, `0xff`)
472/// - float (`1.5`, `.5`, `5.`, `1e2`, `1.5e-3`)
473/// - optional suffix `f` / `F` / `h` / `H` / `u` / `U` / `l` / `L` (consumed
474///   but ignored — the parser uses the post-parsing numeric value).
475fn lex_number(bytes: &[u8], start: usize) -> Result<(TokenKind, usize), String> {
476    let mut i = start;
477    let mut saw_dot = false;
478    let mut saw_exp = false;
479
480    // Hex literal short-circuit.
481    if bytes[i] == b'0' && i + 1 < bytes.len() && (bytes[i + 1] == b'x' || bytes[i + 1] == b'X') {
482        i += 2;
483        let hex_start = i;
484        while i < bytes.len() && bytes[i].is_ascii_hexdigit() {
485            i += 1;
486        }
487        if i == hex_start {
488            return Err("expected hex digits after `0x`".to_string());
489        }
490        let lex = std::str::from_utf8(&bytes[hex_start..i]).map_err(|e| e.to_string())?;
491        let v = i64::from_str_radix(lex, 16).map_err(|e| e.to_string())?;
492        // Consume suffix.
493        while i < bytes.len() && matches!(bytes[i], b'u' | b'U' | b'l' | b'L') {
494            i += 1;
495        }
496        return Ok((TokenKind::IntLit(v), i));
497    }
498
499    // Integer or fractional part.
500    while i < bytes.len() && bytes[i].is_ascii_digit() {
501        i += 1;
502    }
503    // Fractional part.
504    if i < bytes.len() && bytes[i] == b'.' {
505        // Only treat as float if at least one digit appears before or
506        // after the dot. `1.` and `.5` and `1.5` all qualify.
507        // Note: HLSL's `1..5` is illegal, so we don't worry about it.
508        saw_dot = true;
509        i += 1;
510        while i < bytes.len() && bytes[i].is_ascii_digit() {
511            i += 1;
512        }
513    }
514    // Exponent.
515    if i < bytes.len() && (bytes[i] == b'e' || bytes[i] == b'E') {
516        saw_exp = true;
517        i += 1;
518        if i < bytes.len() && (bytes[i] == b'+' || bytes[i] == b'-') {
519            i += 1;
520        }
521        let exp_start = i;
522        while i < bytes.len() && bytes[i].is_ascii_digit() {
523            i += 1;
524        }
525        if i == exp_start {
526            return Err("expected digits in exponent".to_string());
527        }
528    }
529    // Suffix.
530    let is_float_suffix = i < bytes.len() && matches!(bytes[i], b'f' | b'F' | b'h' | b'H');
531    let is_int_suffix = i < bytes.len() && matches!(bytes[i], b'u' | b'U' | b'l' | b'L');
532    if is_float_suffix {
533        i += 1;
534        // `f`/`h` suffix forces the literal to a float regardless of how
535        // the mantissa was written (`1f` → 1.0).
536        saw_dot = true;
537    } else if is_int_suffix {
538        i += 1;
539    }
540
541    let raw = std::str::from_utf8(&bytes[start..i]).map_err(|e| e.to_string())?;
542    // Strip the suffix character before std parse.
543    let stripped = if is_float_suffix || is_int_suffix {
544        &raw[..raw.len() - 1]
545    } else {
546        raw
547    };
548
549    if saw_dot || saw_exp {
550        let v: f64 = stripped
551            .parse()
552            .map_err(|e: std::num::ParseFloatError| e.to_string())?;
553        Ok((TokenKind::FloatLit(v), i))
554    } else {
555        // HLSL allows leading zeros (`02` == 2). std `i64::from_str` accepts
556        // them too. We accept any width that fits in i64.
557        let v: i64 = stripped
558            .parse()
559            .map_err(|e: std::num::ParseIntError| e.to_string())?;
560        Ok((TokenKind::IntLit(v), i))
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Lex `src` (which must succeed) and keep only the token kinds.
    fn kinds(src: &str) -> Vec<TokenKind> {
        let tokens = tokenize(src).unwrap();
        tokens.into_iter().map(|tok| tok.kind).collect()
    }

    /// Source text covered by a token's span.
    fn lexeme<'a>(src: &'a str, tok: &Token) -> &'a str {
        &src[tok.span.start as usize..tok.span.end as usize]
    }

    /// True when `kind` is a float literal within `tol` of `expected`.
    fn is_float(kind: &TokenKind, expected: f64, tol: f64) -> bool {
        match kind {
            TokenKind::FloatLit(v) => (*v - expected).abs() < tol,
            _ => false,
        }
    }

    #[test]
    fn empty_input_yields_eof() {
        assert_eq!(kinds(""), vec![TokenKind::Eof]);
    }

    #[test]
    fn integer_literals() {
        let mut expected: Vec<TokenKind> = [0, 1, 12, 2, 0xff, 0x1F]
            .iter()
            .map(|&v| TokenKind::IntLit(v))
            .collect();
        expected.push(TokenKind::Eof);
        assert_eq!(kinds("0 1 12 02 0xff 0x1F"), expected);
    }

    #[test]
    fn float_literals_with_suffixes() {
        // (expected value, tolerance) per literal, in source order.
        let expected = [
            (1.5, 1e-9),
            (0.5, 1e-9),
            (5.0, 1e-9),
            (100.0, 1e-9),
            (0.0015, 1e-12),
            (1.5, 1e-9),
            (0.5, 1e-9),
        ];
        let got = kinds("1.5 .5 5. 1e2 1.5e-3 1.5f 0.5h");
        // Seven literals plus the trailing Eof.
        assert_eq!(got.len(), 8);
        for (kind, &(value, tol)) in got.iter().zip(expected.iter()) {
            assert!(is_float(kind, value, tol));
        }
    }

    #[test]
    fn identifier_and_keyword_split() {
        // `float2` must stay an identifier: the parser recognises type
        // names from declaration context, so future matrix names need no
        // lexer change.
        let expected = vec![
            TokenKind::Keyword(Keyword::If),
            TokenKind::Keyword(Keyword::While),
            TokenKind::Ident,
            TokenKind::Ident,
            TokenKind::Eof,
        ];
        assert_eq!(kinds("if while float2 my_var"), expected);
    }

    #[test]
    fn bool_literals() {
        let expected = vec![
            TokenKind::BoolLit(true),
            TokenKind::BoolLit(false),
            TokenKind::Eof,
        ];
        assert_eq!(kinds("true false"), expected);
    }

    #[test]
    fn line_and_block_comments_skipped() {
        let got = kinds("a // comment\nb /* multi\nline */ c");
        // a, b, c, eof
        assert_eq!(got.len(), 4);
        for kind in &got[..3] {
            assert_eq!(*kind, TokenKind::Ident);
        }
    }

    #[test]
    fn preprocessor_lines_skipped() {
        let got = kinds("#define X 1\nfoo");
        assert_eq!(got, vec![TokenKind::Ident, TokenKind::Eof]);
    }

    #[test]
    fn compound_operators() {
        let expected = vec![
            TokenKind::PlusEq,
            TokenKind::MinusEq,
            TokenKind::StarEq,
            TokenKind::SlashEq,
            TokenKind::PercentEq,
            TokenKind::EqEq,
            TokenKind::BangEq,
            TokenKind::LtEq,
            TokenKind::GtEq,
            TokenKind::AmpAmp,
            TokenKind::BarBar,
            TokenKind::PlusPlus,
            TokenKind::MinusMinus,
            TokenKind::Shl,
            TokenKind::Shr,
            TokenKind::Eof,
        ];
        assert_eq!(
            kinds("+= -= *= /= %= == != <= >= && || ++ -- << >>"),
            expected
        );
    }

    #[test]
    fn single_operators() {
        let expected = vec![
            TokenKind::Plus,
            TokenKind::Minus,
            TokenKind::Star,
            TokenKind::Slash,
            TokenKind::Percent,
            TokenKind::Eq,
            TokenKind::Lt,
            TokenKind::Gt,
            TokenKind::Amp,
            TokenKind::Bar,
            TokenKind::Caret,
            TokenKind::Tilde,
            TokenKind::Bang,
            TokenKind::Question,
            TokenKind::Colon,
            TokenKind::Semi,
            TokenKind::Eof,
        ];
        assert_eq!(kinds("+ - * / % = < > & | ^ ~ ! ? : ;"), expected);
    }

    #[test]
    fn punctuation() {
        let expected = vec![
            TokenKind::LParen,
            TokenKind::RParen,
            TokenKind::LBrace,
            TokenKind::RBrace,
            TokenKind::LBracket,
            TokenKind::RBracket,
            TokenKind::Comma,
            TokenKind::Dot,
            TokenKind::Eof,
        ];
        assert_eq!(kinds("( ) { } [ ] , ."), expected);
    }

    #[test]
    fn spans_track_line_and_column() {
        let src = "abc\n  xyz";
        let toks = tokenize(src).unwrap();
        assert_eq!(toks.len(), 3);

        let (first, second) = (&toks[0], &toks[1]);
        assert_eq!(first.span.line, 1);
        assert_eq!(first.span.col, 1);
        assert_eq!(second.span.line, 2);
        assert_eq!(second.span.col, 3);
        assert_eq!(lexeme(src, first), "abc");
        assert_eq!(lexeme(src, second), "xyz");
    }

    #[test]
    fn unterminated_block_comment_errors() {
        let err = tokenize("a /* never closed").unwrap_err();
        assert!(err.message.contains("unterminated block comment"));
    }

    #[test]
    fn unexpected_char_errors() {
        let err = tokenize("a @ b").unwrap_err();
        assert!(err.message.contains("unexpected character"));
    }

    #[test]
    fn md2_sample_warp_body_tokenises_cleanly() {
        // Realistic snippet from a MD2 warp shader: exercises `float2`,
        // swizzles and compound `*=` in one go.
        let src = r#"
        shader_body {
            float2 uv2 = uv - 0.5;
            uv2 *= aspect.xy;
            float dist = length(uv2);
            ret += tex2D(sampler_main, uv2) * 0.5;
        }
        "#;
        let toks = tokenize(src).unwrap();

        // `shader_body` is lexed as its own keyword.
        let has_shader_body = toks
            .iter()
            .any(|t| t.kind == TokenKind::Keyword(Keyword::ShaderBody));
        assert!(has_shader_body);

        // `*=` is one token, not `*` followed by `=`.
        let star_eq_count = toks.iter().filter(|t| t.kind == TokenKind::StarEq).count();
        assert_eq!(star_eq_count, 1);

        // `0.5` takes the float-literal path.
        let has_half = toks.iter().any(|t| is_float(&t.kind, 0.5, 1e-9));
        assert!(has_half);
    }

    #[test]
    fn leading_zero_decimal_parses_as_value() {
        // HLSL reads `02` as the integer 2. The regex translator strips
        // leading zeros before emitting WGSL; the lexer accepts the raw
        // form as written.
        let got = kinds("02");
        assert_eq!(got, vec![TokenKind::IntLit(2), TokenKind::Eof]);
    }

    #[test]
    fn matrix_type_kept_as_identifier() {
        // `float2x2` / `mat3x3`-style names stay `Ident`; the parser
        // disambiguates them from context.
        let src = "float2x2 m;";
        let toks = tokenize(src).unwrap();
        assert_eq!(toks[0].kind, TokenKind::Ident);
        assert_eq!(lexeme(src, &toks[0]), "float2x2");
    }

    #[test]
    fn span_merge_combines_ranges() {
        let lhs = Span {
            start: 5,
            end: 8,
            line: 2,
            col: 3,
        };
        let rhs = Span {
            start: 12,
            end: 15,
            line: 2,
            col: 10,
        };
        let merged = lhs.merge(rhs);
        assert_eq!(merged.start, 5);
        assert_eq!(merged.end, 15);
        assert_eq!(merged.col, 3);
    }
}
818}