onedrop_hlsl/
parse.rs

1//! HLSL recursive-descent parser.
2//!
3//! Consumes the [`Token`] stream from [`crate::lex::tokenize`] and produces a
4//! [`TranslationUnit`] AST. Pratt-style operator precedence for expressions,
5//! straightforward recursive descent for statements and top-level items.
6//!
7//! The parser is pragmatic, not exhaustive: it handles the HLSL subset that
8//! MilkDrop 2 user comp shaders actually use. Constructs outside that subset
9//! (templates, attributes, geometry-shader semantics, etc.) produce a
10//! [`ParseError`] — the caller falls back to the existing regex pipeline.
11//!
12//! This parser is not yet wired into [`crate::translate_shader`]. The AST
13//! sits available for a future WGSL emitter, while the existing regex passes
14//! continue to drive the current comp-shader pass-rate on
15//! `test-presets-200/`.
16
17use crate::ast::*;
18use crate::lex::{Keyword, LexError, Span, Token, TokenKind, tokenize};
19use std::fmt;
20
21/// Parse error: human-readable message plus the offending source span.
22#[derive(Debug, Clone, PartialEq)]
23pub struct ParseError {
24    pub message: String,
25    pub span: Span,
26}
27
28impl fmt::Display for ParseError {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        write!(
31            f,
32            "parse error at {}:{}: {}",
33            self.span.line, self.span.col, self.message
34        )
35    }
36}
37
38impl std::error::Error for ParseError {}
39
40impl From<LexError> for ParseError {
41    fn from(e: LexError) -> Self {
42        Self {
43            message: e.message,
44            span: Span {
45                start: e.offset,
46                end: e.offset,
47                line: e.line,
48                col: e.col,
49            },
50        }
51    }
52}
53
54/// Top-level entry point: lex then parse a full HLSL source string.
55///
56/// Accepts either a bare statement block (no `shader_body` wrapper) or a
57/// full MD2 comp shader (`<global decls> shader_body { ... }`). Returns the
58/// translation unit AST or a [`ParseError`] at the first unrecoverable
59/// position.
60pub fn parse_hlsl(src: &str) -> Result<TranslationUnit, ParseError> {
61    let tokens = tokenize(src)?;
62    let mut p = Parser::new(&tokens, src);
63    p.parse_translation_unit()
64}
65
66/// Hand-written recursive-descent parser. Holds a borrow on the token slice
67/// and source string; produces AST nodes that own their data (so the
68/// parser's lifetime doesn't leak into downstream consumers).
69struct Parser<'a> {
70    tokens: &'a [Token],
71    src: &'a str,
72    pos: usize,
73}
74
75impl<'a> Parser<'a> {
76    fn new(tokens: &'a [Token], src: &'a str) -> Self {
77        Self {
78            tokens,
79            src,
80            pos: 0,
81        }
82    }
83
84    // -------- token cursor helpers --------
85
86    fn peek(&self) -> &Token {
87        &self.tokens[self.pos]
88    }
89
90    fn peek_at(&self, offset: usize) -> Option<&Token> {
91        self.tokens.get(self.pos + offset)
92    }
93
94    fn bump(&mut self) -> &Token {
95        let t = &self.tokens[self.pos];
96        if !matches!(t.kind, TokenKind::Eof) {
97            self.pos += 1;
98        }
99        t
100    }
101
102    fn at(&self, kind: &TokenKind) -> bool {
103        std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(kind)
104    }
105
106    fn eat(&mut self, kind: &TokenKind) -> bool {
107        if self.at(kind) {
108            self.bump();
109            true
110        } else {
111            false
112        }
113    }
114
115    fn expect(&mut self, kind: TokenKind, ctx: &str) -> Result<(), ParseError> {
116        if std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(&kind) {
117            self.bump();
118            Ok(())
119        } else {
120            Err(ParseError {
121                message: format!("expected {ctx}, found {:?}", self.peek().kind),
122                span: self.peek().span,
123            })
124        }
125    }
126
127    fn ident_text(&self, span: Span) -> &'a str {
128        &self.src[span.start as usize..span.end as usize]
129    }
130
131    /// Consume the current token if it's an `Ident` and return its text.
132    /// Errors with `ctx` in the message otherwise.
133    fn expect_ident(&mut self, ctx: &str) -> Result<(String, Span), ParseError> {
134        let t = self.peek();
135        if matches!(t.kind, TokenKind::Ident) {
136            let span = t.span;
137            let s = self.ident_text(span).to_string();
138            self.bump();
139            Ok((s, span))
140        } else {
141            Err(ParseError {
142                message: format!("expected {ctx}, found {:?}", t.kind),
143                span: t.span,
144            })
145        }
146    }
147
148    // -------- top-level --------
149
150    fn parse_translation_unit(&mut self) -> Result<TranslationUnit, ParseError> {
151        let start_span = self.peek().span;
152        let mut items = Vec::new();
153        let mut shader_body = None;
154
155        while !matches!(self.peek().kind, TokenKind::Eof) {
156            if matches!(self.peek().kind, TokenKind::Keyword(Keyword::ShaderBody)) {
157                let kw_span = self.peek().span;
158                self.bump();
159                self.expect(TokenKind::LBrace, "`{` after `shader_body`")?;
160                shader_body = Some(self.parse_block_after_brace(kw_span)?);
161                break;
162            }
163            // Sampler decl: `sampler[2D|3D] <name>;`
164            if matches!(self.peek().kind, TokenKind::Keyword(Keyword::Sampler)) {
165                items.push(self.parse_sampler_decl()?);
166                continue;
167            }
168            // `[static] [const] <type> <name>[, <name>]... ;` global var.
169            // Detected by leading qualifier OR by `<type-ident> <name-ident>`
170            // shape (when the third token is not `(`, which would mean a
171            // function definition).
172            if matches!(
173                self.peek().kind,
174                TokenKind::Keyword(Keyword::Static) | TokenKind::Keyword(Keyword::Const)
175            ) {
176                self.parse_global_var_decl(&mut items)?;
177                continue;
178            }
179            // Function vs. global-var disambiguation: both start with
180            // `<Ident> <Ident>`. The third token tells us which:
181            // - `(`     → function
182            // - any other → global variable
183            if self.is_function_signature() {
184                items.push(Item::Function(self.parse_function_def()?));
185                continue;
186            }
187            if self.is_global_var_decl() {
188                self.parse_global_var_decl(&mut items)?;
189                continue;
190            }
191            // Anything else at top level is an error.
192            return Err(ParseError {
193                message: format!("unexpected top-level token {:?}", self.peek().kind),
194                span: self.peek().span,
195            });
196        }
197
198        let end_span = self.peek().span;
199        Ok(TranslationUnit {
200            items,
201            shader_body,
202            span: start_span.merge(end_span),
203        })
204    }
205
206    /// Peek 3 tokens to decide if we're looking at a function definition:
207    /// `<type-ident> <name-ident> (`. Doesn't consume.
208    fn is_function_signature(&self) -> bool {
209        let t0 = matches!(
210            self.peek().kind,
211            TokenKind::Ident | TokenKind::Keyword(Keyword::Void)
212        );
213        let t1 = matches!(self.peek_at(1).map(|t| &t.kind), Some(TokenKind::Ident));
214        let t2 = matches!(self.peek_at(2).map(|t| &t.kind), Some(TokenKind::LParen));
215        t0 && t1 && t2
216    }
217
218    /// Top-level global var decl detection: `<type-ident> <name-ident>`
219    /// where the third token is *not* `(`. Doesn't consume.
220    fn is_global_var_decl(&self) -> bool {
221        let t0 = matches!(self.peek().kind, TokenKind::Ident);
222        let t1 = matches!(self.peek_at(1).map(|t| &t.kind), Some(TokenKind::Ident));
223        let t2_not_paren = !matches!(self.peek_at(2).map(|t| &t.kind), Some(TokenKind::LParen));
224        t0 && t1 && t2_not_paren
225    }
226
227    /// Parse `[static] [const] <type> <name>[<array>] [= <init>] [, <name> ...] ;`
228    /// and flatten the comma list into one [`Item::GlobalVar`] per name.
229    /// The qualifier flags propagate to every flattened variable.
230    fn parse_global_var_decl(&mut self, items: &mut Vec<Item>) -> Result<(), ParseError> {
231        let start = self.peek().span;
232        let mut is_static = false;
233        let mut is_const = false;
234        loop {
235            match self.peek().kind {
236                TokenKind::Keyword(Keyword::Static) => {
237                    is_static = true;
238                    self.bump();
239                }
240                TokenKind::Keyword(Keyword::Const) => {
241                    is_const = true;
242                    self.bump();
243                }
244                _ => break,
245            }
246        }
247        let ty = self.parse_type()?;
248        loop {
249            let (name, name_span) = self.expect_ident("global variable name")?;
250            let array_len = if self.eat(&TokenKind::LBracket) {
251                let len = self.parse_expr()?;
252                self.expect(TokenKind::RBracket, "`]` after array length")?;
253                Some(len)
254            } else {
255                None
256            };
257            let init = if self.eat(&TokenKind::Eq) {
258                Some(self.parse_expr()?)
259            } else {
260                None
261            };
262            items.push(Item::GlobalVar(GlobalVar {
263                is_static,
264                is_const,
265                ty: ty.clone(),
266                name,
267                array_len,
268                init,
269                span: start.merge(name_span),
270            }));
271            if !self.eat(&TokenKind::Comma) {
272                break;
273            }
274        }
275        self.expect(TokenKind::Semi, "`;` at end of global declaration")?;
276        Ok(())
277    }
278
279    fn parse_sampler_decl(&mut self) -> Result<Item, ParseError> {
280        let kw_tok = self.bump();
281        let kw_span = kw_tok.span;
282        let kw_text = self.ident_text(kw_span);
283        let tag = match kw_text {
284            "sampler2D" => SamplerTag::Sampler2D,
285            "sampler3D" => SamplerTag::Sampler3D,
286            _ => SamplerTag::Sampler,
287        };
288        let (name, name_span) = self.expect_ident("sampler name")?;
289        // Optional `: register(s0)` or `= sampler_state { ... };` — both are
290        // legal HLSL. We swallow up to the next top-level `;` so the parser
291        // doesn't need to model the state block.
292        while !matches!(self.peek().kind, TokenKind::Semi | TokenKind::Eof) {
293            // Skip a brace-balanced block if we see one.
294            if matches!(self.peek().kind, TokenKind::LBrace) {
295                self.skip_braced_block()?;
296                continue;
297            }
298            self.bump();
299        }
300        self.expect(TokenKind::Semi, "`;` after sampler declaration")?;
301        Ok(Item::SamplerDecl(SamplerDecl {
302            tag,
303            name,
304            span: kw_span.merge(name_span),
305        }))
306    }
307
308    /// `{ ... }` skipper for opaque blocks (sampler_state, struct bodies in
309    /// the rare cases MD2 uses them). Caller positions us at `{`; we end
310    /// after the matching `}`.
311    fn skip_braced_block(&mut self) -> Result<(), ParseError> {
312        let open = self.peek().span;
313        self.expect(TokenKind::LBrace, "`{`")?;
314        let mut depth = 1i32;
315        while depth > 0 {
316            match self.peek().kind {
317                TokenKind::LBrace => {
318                    depth += 1;
319                    self.bump();
320                }
321                TokenKind::RBrace => {
322                    depth -= 1;
323                    self.bump();
324                }
325                TokenKind::Eof => {
326                    return Err(ParseError {
327                        message: "unterminated `{` block".to_string(),
328                        span: open,
329                    });
330                }
331                _ => {
332                    self.bump();
333                }
334            }
335        }
336        Ok(())
337    }
338
339    fn parse_function_def(&mut self) -> Result<FunctionDef, ParseError> {
340        let start = self.peek().span;
341        let return_type = self.parse_type()?;
342        let (name, _) = self.expect_ident("function name")?;
343        self.expect(TokenKind::LParen, "`(` after function name")?;
344        let mut params = Vec::new();
345        if !matches!(self.peek().kind, TokenKind::RParen) {
346            loop {
347                params.push(self.parse_param()?);
348                if !self.eat(&TokenKind::Comma) {
349                    break;
350                }
351            }
352        }
353        self.expect(TokenKind::RParen, "`)` after parameter list")?;
354        let body = self.parse_block()?;
355        let end = body.span;
356        Ok(FunctionDef {
357            return_type,
358            name,
359            params,
360            body,
361            span: start.merge(end),
362        })
363    }
364
365    fn parse_param(&mut self) -> Result<Param, ParseError> {
366        let start = self.peek().span;
367        let qualifier = match self.peek().kind {
368            TokenKind::Keyword(Keyword::In) => {
369                self.bump();
370                Some(ParamQualifier::In)
371            }
372            TokenKind::Keyword(Keyword::Out) => {
373                self.bump();
374                Some(ParamQualifier::Out)
375            }
376            TokenKind::Keyword(Keyword::InOut) => {
377                self.bump();
378                Some(ParamQualifier::InOut)
379            }
380            _ => None,
381        };
382        let ty = self.parse_type()?;
383        let (name, name_span) = self.expect_ident("parameter name")?;
384        // HLSL allows `: SEMANTIC` after a parameter — skip silently.
385        if self.eat(&TokenKind::Colon) {
386            // Eat one identifier (the semantic).
387            if matches!(self.peek().kind, TokenKind::Ident) {
388                self.bump();
389            }
390        }
391        Ok(Param {
392            qualifier,
393            ty,
394            name,
395            span: start.merge(name_span),
396        })
397    }
398
399    /// `<type-ident>` — one identifier consumed. We do not look ahead for
400    /// matrix bracket syntax; HLSL spells matrices as compound idents
401    /// (`float2x2`, `mat3x3`).
402    fn parse_type(&mut self) -> Result<TypeRef, ParseError> {
403        if matches!(self.peek().kind, TokenKind::Keyword(Keyword::Void)) {
404            let span = self.peek().span;
405            self.bump();
406            return Ok(TypeRef::new("void", span));
407        }
408        if matches!(self.peek().kind, TokenKind::Ident) {
409            let span = self.peek().span;
410            let name = self.ident_text(span).to_string();
411            self.bump();
412            return Ok(TypeRef::new(name, span));
413        }
414        Err(ParseError {
415            message: format!("expected type name, found {:?}", self.peek().kind),
416            span: self.peek().span,
417        })
418    }
419
420    // -------- statements --------
421
422    fn parse_block(&mut self) -> Result<Block, ParseError> {
423        let open_span = self.peek().span;
424        self.expect(TokenKind::LBrace, "`{` at start of block")?;
425        self.parse_block_after_brace(open_span)
426    }
427
428    fn parse_block_after_brace(&mut self, open_span: Span) -> Result<Block, ParseError> {
429        let mut stmts = Vec::new();
430        while !matches!(self.peek().kind, TokenKind::RBrace | TokenKind::Eof) {
431            self.parse_stmt_unit(&mut stmts)?;
432        }
433        let close_span = self.peek().span;
434        self.expect(TokenKind::RBrace, "`}` at end of block")?;
435        Ok(Block {
436            stmts,
437            span: open_span.merge(close_span),
438        })
439    }
440
441    /// Parse one *syntactic* statement unit and push the resulting Stmt(s)
442    /// onto `out`. Most kinds push exactly one statement, but two HLSL
443    /// idioms produce multiple:
444    ///
445    /// - Multi-name declarations (`float a, b, c;`) — one Stmt per name.
446    /// - Comma-separated statements (`a += 1, b += 2;`) — one Stmt per
447    ///   comma-separated assignment/expression.
448    ///
449    /// The block parser calls this in a loop so the caller-side AST sees a
450    /// flat statement list — convenient for downstream emitters that don't
451    /// care about source-line grouping.
452    fn parse_stmt_unit(&mut self, out: &mut Vec<Stmt>) -> Result<(), ParseError> {
453        let single = match self.peek().kind {
454            TokenKind::LBrace => Some(Stmt::Block(self.parse_block()?)),
455            TokenKind::Keyword(Keyword::If) => Some(self.parse_if()?),
456            TokenKind::Keyword(Keyword::While) => Some(self.parse_while()?),
457            TokenKind::Keyword(Keyword::Do) => Some(self.parse_do_while()?),
458            TokenKind::Keyword(Keyword::For) => Some(self.parse_for()?),
459            TokenKind::Keyword(Keyword::Return) => Some(self.parse_return()?),
460            TokenKind::Keyword(Keyword::Break) => {
461                self.bump();
462                self.expect(TokenKind::Semi, "`;` after `break`")?;
463                Some(Stmt::Break)
464            }
465            TokenKind::Keyword(Keyword::Continue) => {
466                self.bump();
467                self.expect(TokenKind::Semi, "`;` after `continue`")?;
468                Some(Stmt::Continue)
469            }
470            TokenKind::Semi => {
471                self.bump();
472                None
473            }
474            _ => {
475                self.parse_decl_or_expr_stmts(out)?;
476                return Ok(());
477            }
478        };
479        if let Some(s) = single {
480            out.push(s);
481        }
482        Ok(())
483    }
484
485    /// Compatibility shim — many internal callsites (for-loop init, etc.)
486    /// still want a single Stmt. Calls [`parse_stmt_unit`] and pulls the
487    /// first emitted statement; subsequent statements (rare in those
488    /// contexts) are concatenated as a synthetic [`Stmt::Block`] so no
489    /// declared variables go missing.
490    fn parse_stmt(&mut self) -> Result<Stmt, ParseError> {
491        let mut buf = Vec::with_capacity(1);
492        self.parse_stmt_unit(&mut buf)?;
493        match buf.len() {
494            0 => Ok(Stmt::Block(Block {
495                stmts: Vec::new(),
496                span: self.peek().span,
497            })),
498            1 => Ok(buf.into_iter().next().unwrap()),
499            _ => {
500                let span = buf
501                    .first()
502                    .map(stmt_span)
503                    .unwrap_or(Span::dummy())
504                    .merge(buf.last().map(stmt_span).unwrap_or(Span::dummy()));
505                Ok(Stmt::Block(Block { stmts: buf, span }))
506            }
507        }
508    }
509
510    /// The tricky dispatch: a statement starts with an `Ident`. Decide
511    /// between local decl (`<type> <name>`) and expression statement
512    /// (`<expr>;` or `<lhs> [op]= <rhs>;`).
513    ///
514    /// Heuristic: if the lookahead is `Ident Ident`, it's a decl. We also
515    /// allow the leading qualifier `const` for completeness even though MD2
516    /// rarely uses it locally.
517    fn parse_decl_or_expr_stmts(&mut self, out: &mut Vec<Stmt>) -> Result<(), ParseError> {
518        // Strip leading `const` if present — local `const` decls are legal
519        // in HLSL even though MD2 rarely uses them.
520        let saved_pos = self.pos;
521        let mut had_const = false;
522        if matches!(self.peek().kind, TokenKind::Keyword(Keyword::Const)) {
523            self.bump();
524            had_const = true;
525        }
526
527        if matches!(self.peek().kind, TokenKind::Ident)
528            && matches!(self.peek_at(1).map(|t| &t.kind), Some(TokenKind::Ident))
529        {
530            return self.parse_local_decl_tail(out);
531        }
532
533        // Not a decl — rewind the `const` skip and parse as expression
534        // statement.
535        if had_const {
536            self.pos = saved_pos;
537        }
538        self.parse_expr_or_assign_stmts(out)
539    }
540
541    /// Parse the rest of a local decl after the type+name lookahead
542    /// matched. Eats: `<type> <name>[<dim>] [= <init>] [, <name>...] ;`.
543    /// Pushes one [`Stmt::LocalDecl`] per name onto `out`.
544    fn parse_local_decl_tail(&mut self, out: &mut Vec<Stmt>) -> Result<(), ParseError> {
545        let start = self.peek().span;
546        let ty = self.parse_type()?;
547        loop {
548            let (name, name_span) = self.expect_ident("variable name")?;
549            let mut array_len = None;
550            if self.eat(&TokenKind::LBracket) {
551                array_len = Some(self.parse_expr()?);
552                self.expect(TokenKind::RBracket, "`]` after array length")?;
553            }
554            let init = if self.eat(&TokenKind::Eq) {
555                // Chained-assign init — HLSL accepts `T x = y = z = expr;`,
556                // evaluating right-to-left. parse_expr stops at `=` (not a
557                // binary op), so each `=` we still see after the init expr
558                // is the next link in the chain. We collect the chain and
559                // emit each intermediate `lhs = rhs` as its own assignment
560                // statement *before* the LocalDecl; the leftmost LHS then
561                // initialises the new variable. WGSL has no assignment
562                // expression, so lowering at parse time keeps the rest of
563                // the pipeline free of `var x = (y = ...)` shapes.
564                let mut chain = vec![self.parse_expr()?];
565                while self.eat(&TokenKind::Eq) {
566                    chain.push(self.parse_expr()?);
567                }
568                for i in (0..chain.len().saturating_sub(1)).rev() {
569                    let target = chain[i].clone();
570                    let value = chain[i + 1].clone();
571                    let span = target.span().merge(value.span());
572                    out.push(Stmt::Assign(AssignStmt {
573                        target,
574                        op: AssignOp::Set,
575                        value,
576                        span,
577                    }));
578                }
579                Some(chain.into_iter().next().unwrap())
580            } else {
581                None
582            };
583            out.push(Stmt::LocalDecl(LocalDecl {
584                ty: ty.clone(),
585                name,
586                array_len,
587                init,
588                span: start.merge(name_span),
589            }));
590            if !self.eat(&TokenKind::Comma) {
591                break;
592            }
593        }
594        self.expect(TokenKind::Semi, "`;` at end of declaration")?;
595        Ok(())
596    }
597
598    /// Parse one or more comma-separated `<expr>` or `<lhs> [op]= <rhs>`
599    /// statements, terminated by `;`. MD2 user shaders use the comma
600    /// form (`a += 1, b += 2;`) freely; we flatten into multiple
601    /// statements at parse time so downstream emitters don't need to
602    /// understand the comma operator.
603    fn parse_expr_or_assign_stmts(&mut self, out: &mut Vec<Stmt>) -> Result<(), ParseError> {
604        loop {
605            let start = self.peek().span;
606            let lhs = self.parse_expr()?;
607            if let Some(op) = self.peek_assign_op() {
608                self.bump();
609                let rhs = self.parse_expr()?;
610                let span = start.merge(rhs.span());
611                out.push(Stmt::Assign(AssignStmt {
612                    target: lhs,
613                    op,
614                    value: rhs,
615                    span,
616                }));
617            } else {
618                out.push(Stmt::Expr(lhs));
619            }
620            if !self.eat(&TokenKind::Comma) {
621                break;
622            }
623        }
624        self.expect(TokenKind::Semi, "`;` after statement")?;
625        Ok(())
626    }
627
628    fn peek_assign_op(&self) -> Option<AssignOp> {
629        Some(match self.peek().kind {
630            TokenKind::Eq => AssignOp::Set,
631            TokenKind::PlusEq => AssignOp::Add,
632            TokenKind::MinusEq => AssignOp::Sub,
633            TokenKind::StarEq => AssignOp::Mul,
634            TokenKind::SlashEq => AssignOp::Div,
635            TokenKind::PercentEq => AssignOp::Rem,
636            _ => return None,
637        })
638    }
639
640    fn parse_if(&mut self) -> Result<Stmt, ParseError> {
641        let start = self.peek().span;
642        self.bump(); // if
643        self.expect(TokenKind::LParen, "`(` after `if`")?;
644        let cond = self.parse_expr()?;
645        self.expect(TokenKind::RParen, "`)` after `if` condition")?;
646        let then_branch = Box::new(self.parse_stmt()?);
647        let else_branch = if matches!(self.peek().kind, TokenKind::Keyword(Keyword::Else)) {
648            self.bump();
649            Some(Box::new(self.parse_stmt()?))
650        } else {
651            None
652        };
653        let end_span = else_branch
654            .as_ref()
655            .map(|b| stmt_span(b))
656            .unwrap_or(stmt_span(&then_branch));
657        Ok(Stmt::If(IfStmt {
658            cond,
659            then_branch,
660            else_branch,
661            span: start.merge(end_span),
662        }))
663    }
664
665    fn parse_while(&mut self) -> Result<Stmt, ParseError> {
666        let start = self.peek().span;
667        self.bump(); // while
668        self.expect(TokenKind::LParen, "`(` after `while`")?;
669        let cond = self.parse_expr()?;
670        self.expect(TokenKind::RParen, "`)` after `while` condition")?;
671        let body = Box::new(self.parse_stmt()?);
672        let end = stmt_span(&body);
673        Ok(Stmt::While(WhileStmt {
674            cond,
675            body,
676            do_while: false,
677            span: start.merge(end),
678        }))
679    }
680
681    fn parse_do_while(&mut self) -> Result<Stmt, ParseError> {
682        let start = self.peek().span;
683        self.bump(); // do
684        let body = Box::new(self.parse_stmt()?);
685        self.expect(
686            TokenKind::Keyword(Keyword::While),
687            "`while` after `do` body",
688        )?;
689        self.expect(TokenKind::LParen, "`(` after `while`")?;
690        let cond = self.parse_expr()?;
691        self.expect(TokenKind::RParen, "`)` after `while` condition")?;
692        let end = self.peek().span;
693        self.expect(TokenKind::Semi, "`;` after do-while")?;
694        Ok(Stmt::While(WhileStmt {
695            cond,
696            body,
697            do_while: true,
698            span: start.merge(end),
699        }))
700    }
701
702    fn parse_for(&mut self) -> Result<Stmt, ParseError> {
703        let start = self.peek().span;
704        self.bump(); // for
705        self.expect(TokenKind::LParen, "`(` after `for`")?;
706        let init = if matches!(self.peek().kind, TokenKind::Semi) {
707            self.bump();
708            None
709        } else {
710            // Init position: a decl or an expression statement; both end
711            // with `;`. parse_stmt handles either branch.
712            Some(Box::new(self.parse_stmt()?))
713        };
714        let cond = if matches!(self.peek().kind, TokenKind::Semi) {
715            None
716        } else {
717            Some(self.parse_expr()?)
718        };
719        self.expect(TokenKind::Semi, "`;` between `for` cond and step")?;
720        let step = if matches!(self.peek().kind, TokenKind::RParen) {
721            None
722        } else {
723            // `for` step accepts an assignment-shaped expression
724            // (`i = i + 1`) — captured as an [`Expr::Assign`] node so the
725            // AST keeps a typed handle on the LHS, the operator, and the
726            // RHS. Plain expressions (`i++`) parse through `parse_expr`
727            // without an assignment tail.
728            let base = self.parse_expr()?;
729            if let Some(op) = self.peek_assign_op() {
730                self.bump();
731                let rhs = self.parse_expr()?;
732                let span = base.span().merge(rhs.span());
733                Some(Expr::Assign(AssignExpr {
734                    target: Box::new(base),
735                    op,
736                    value: Box::new(rhs),
737                    span,
738                }))
739            } else {
740                Some(base)
741            }
742        };
743        self.expect(TokenKind::RParen, "`)` to close `for` header")?;
744        let body = Box::new(self.parse_stmt()?);
745        let end = stmt_span(&body);
746        Ok(Stmt::For(ForStmt {
747            init,
748            cond,
749            step,
750            body,
751            span: start.merge(end),
752        }))
753    }
754
755    fn parse_return(&mut self) -> Result<Stmt, ParseError> {
756        self.bump(); // return
757        let val = if matches!(self.peek().kind, TokenKind::Semi) {
758            None
759        } else {
760            Some(self.parse_expr()?)
761        };
762        self.expect(TokenKind::Semi, "`;` after `return`")?;
763        Ok(Stmt::Return(val))
764    }
765
766    // -------- expressions (Pratt) --------
767
768    fn parse_expr(&mut self) -> Result<Expr, ParseError> {
769        self.parse_ternary()
770    }
771
772    fn parse_ternary(&mut self) -> Result<Expr, ParseError> {
773        let cond = self.parse_logical_or()?;
774        if self.eat(&TokenKind::Question) {
775            let then_expr = self.parse_expr()?;
776            self.expect(TokenKind::Colon, "`:` in ternary expression")?;
777            let else_expr = self.parse_expr()?;
778            let span = cond.span().merge(else_expr.span());
779            return Ok(Expr::Ternary(TernaryExpr {
780                cond: Box::new(cond),
781                then_expr: Box::new(then_expr),
782                else_expr: Box::new(else_expr),
783                span,
784            }));
785        }
786        Ok(cond)
787    }
788
789    fn parse_binary_left<F>(
790        &mut self,
791        mut next: F,
792        accept: &[(TokenKind, BinaryOp)],
793    ) -> Result<Expr, ParseError>
794    where
795        F: FnMut(&mut Self) -> Result<Expr, ParseError>,
796    {
797        let mut lhs = next(self)?;
798        while let Some(op) = accept
799            .iter()
800            .find(|(k, _)| std::mem::discriminant(&self.peek().kind) == std::mem::discriminant(k))
801            .map(|(_, op)| *op)
802        {
803            self.bump();
804            let rhs = next(self)?;
805            let span = lhs.span().merge(rhs.span());
806            lhs = Expr::Binary(BinaryExpr {
807                op,
808                lhs: Box::new(lhs),
809                rhs: Box::new(rhs),
810                span,
811            });
812        }
813        Ok(lhs)
814    }
815
816    fn parse_logical_or(&mut self) -> Result<Expr, ParseError> {
817        self.parse_binary_left(
818            |s| s.parse_logical_and(),
819            &[(TokenKind::BarBar, BinaryOp::Or)],
820        )
821    }
822
823    fn parse_logical_and(&mut self) -> Result<Expr, ParseError> {
824        self.parse_binary_left(|s| s.parse_bit_or(), &[(TokenKind::AmpAmp, BinaryOp::And)])
825    }
826
827    fn parse_bit_or(&mut self) -> Result<Expr, ParseError> {
828        self.parse_binary_left(|s| s.parse_bit_xor(), &[(TokenKind::Bar, BinaryOp::BitOr)])
829    }
830
831    fn parse_bit_xor(&mut self) -> Result<Expr, ParseError> {
832        self.parse_binary_left(
833            |s| s.parse_bit_and(),
834            &[(TokenKind::Caret, BinaryOp::BitXor)],
835        )
836    }
837
838    fn parse_bit_and(&mut self) -> Result<Expr, ParseError> {
839        self.parse_binary_left(
840            |s| s.parse_equality(),
841            &[(TokenKind::Amp, BinaryOp::BitAnd)],
842        )
843    }
844
845    fn parse_equality(&mut self) -> Result<Expr, ParseError> {
846        self.parse_binary_left(
847            |s| s.parse_relational(),
848            &[
849                (TokenKind::EqEq, BinaryOp::Eq),
850                (TokenKind::BangEq, BinaryOp::Ne),
851            ],
852        )
853    }
854
855    fn parse_relational(&mut self) -> Result<Expr, ParseError> {
856        self.parse_binary_left(
857            |s| s.parse_shift(),
858            &[
859                (TokenKind::Lt, BinaryOp::Lt),
860                (TokenKind::LtEq, BinaryOp::Le),
861                (TokenKind::Gt, BinaryOp::Gt),
862                (TokenKind::GtEq, BinaryOp::Ge),
863            ],
864        )
865    }
866
867    fn parse_shift(&mut self) -> Result<Expr, ParseError> {
868        self.parse_binary_left(
869            |s| s.parse_additive(),
870            &[
871                (TokenKind::Shl, BinaryOp::Shl),
872                (TokenKind::Shr, BinaryOp::Shr),
873            ],
874        )
875    }
876
877    fn parse_additive(&mut self) -> Result<Expr, ParseError> {
878        self.parse_binary_left(
879            |s| s.parse_multiplicative(),
880            &[
881                (TokenKind::Plus, BinaryOp::Add),
882                (TokenKind::Minus, BinaryOp::Sub),
883            ],
884        )
885    }
886
887    fn parse_multiplicative(&mut self) -> Result<Expr, ParseError> {
888        self.parse_binary_left(
889            |s| s.parse_unary(),
890            &[
891                (TokenKind::Star, BinaryOp::Mul),
892                (TokenKind::Slash, BinaryOp::Div),
893                (TokenKind::Percent, BinaryOp::Rem),
894            ],
895        )
896    }
897
898    fn parse_unary(&mut self) -> Result<Expr, ParseError> {
899        let start = self.peek().span;
900        let op = match self.peek().kind {
901            TokenKind::Minus => Some(UnaryOp::Neg),
902            TokenKind::Plus => Some(UnaryOp::Pos),
903            TokenKind::Bang => Some(UnaryOp::Not),
904            TokenKind::Tilde => Some(UnaryOp::BitNot),
905            _ => None,
906        };
907        if let Some(op) = op {
908            self.bump();
909            let operand = self.parse_unary()?;
910            let span = start.merge(operand.span());
911            return Ok(Expr::Unary(UnaryExpr {
912                op,
913                operand: Box::new(operand),
914                span,
915            }));
916        }
917        self.parse_postfix()
918    }
919
920    fn parse_postfix(&mut self) -> Result<Expr, ParseError> {
921        let mut e = self.parse_primary()?;
922        loop {
923            match self.peek().kind {
924                TokenKind::Dot => {
925                    self.bump();
926                    let (member, m_span) = self.expect_ident("member name after `.`")?;
927                    let span = e.span().merge(m_span);
928                    if is_swizzle(&member) {
929                        e = Expr::Swizzle(SwizzleExpr {
930                            base: Box::new(e),
931                            components: member,
932                            span,
933                        });
934                    } else {
935                        e = Expr::Member(MemberExpr {
936                            base: Box::new(e),
937                            member,
938                            span,
939                        });
940                    }
941                }
942                TokenKind::LBracket => {
943                    self.bump();
944                    let idx = self.parse_expr()?;
945                    let close = self.peek().span;
946                    self.expect(TokenKind::RBracket, "`]` after index")?;
947                    let span = e.span().merge(close);
948                    e = Expr::Index(IndexExpr {
949                        base: Box::new(e),
950                        index: Box::new(idx),
951                        span,
952                    });
953                }
954                TokenKind::PlusPlus | TokenKind::MinusMinus => {
955                    // Postfix increment is a statement in HLSL practice;
956                    // we model it as `x = x + 1` at emit time. For now we
957                    // treat `x++` as just `x` so AST consumers can spot it
958                    // and rewrite — though no emitter does so yet. We bump
959                    // the token to avoid an infinite loop, and emit a
960                    // synthetic binary op to make the value semantics
961                    // close-to-right (`x++` evaluates to `x` pre-increment,
962                    // which we approximate by leaving `e` unchanged).
963                    self.bump();
964                }
965                _ => break,
966            }
967        }
968        Ok(e)
969    }
970
971    fn parse_primary(&mut self) -> Result<Expr, ParseError> {
972        let t = self.peek().clone();
973        match t.kind {
974            TokenKind::IntLit(v) => {
975                self.bump();
976                Ok(Expr::Lit(Lit {
977                    value: LitValue::Int(v),
978                    span: t.span,
979                }))
980            }
981            TokenKind::FloatLit(v) => {
982                self.bump();
983                Ok(Expr::Lit(Lit {
984                    value: LitValue::Float(v),
985                    span: t.span,
986                }))
987            }
988            TokenKind::BoolLit(v) => {
989                self.bump();
990                Ok(Expr::Lit(Lit {
991                    value: LitValue::Bool(v),
992                    span: t.span,
993                }))
994            }
995            TokenKind::Ident => {
996                self.bump();
997                let name = self.ident_text(t.span).to_string();
998                // Call or constructor: `name(...)`.
999                if matches!(self.peek().kind, TokenKind::LParen) {
1000                    self.bump();
1001                    let mut args = Vec::new();
1002                    if !matches!(self.peek().kind, TokenKind::RParen) {
1003                        loop {
1004                            let base = self.parse_expr()?;
1005                            // HLSL accepts an assignment expression as a
1006                            // call argument: `lerp(a, tmp = GetBlur1(uv), b)`
1007                            // sets `tmp` *and* uses the new value as the
1008                            // arg. WGSL has no assignment expression, so a
1009                            // separate AST pass will lift the side-effect
1010                            // out; here we just capture the shape as
1011                            // [`Expr::Assign`] so it round-trips.
1012                            let arg = if let Some(op) = self.peek_assign_op() {
1013                                self.bump();
1014                                let rhs = self.parse_expr()?;
1015                                let span = base.span().merge(rhs.span());
1016                                Expr::Assign(AssignExpr {
1017                                    target: Box::new(base),
1018                                    op,
1019                                    value: Box::new(rhs),
1020                                    span,
1021                                })
1022                            } else {
1023                                base
1024                            };
1025                            args.push(arg);
1026                            if !self.eat(&TokenKind::Comma) {
1027                                break;
1028                            }
1029                        }
1030                    }
1031                    let close = self.peek().span;
1032                    self.expect(TokenKind::RParen, "`)` after call arguments")?;
1033                    return Ok(Expr::Call(CallExpr {
1034                        callee: name,
1035                        args,
1036                        span: t.span.merge(close),
1037                    }));
1038                }
1039                Ok(Expr::Ident(name, t.span))
1040            }
1041            TokenKind::LParen => {
1042                self.bump();
1043                let mut inner = self.parse_expr()?;
1044                // C-style comma operator inside parens: `(a, b, c)` evaluates
1045                // all in order and yields `c`. MD2 occasionally writes
1046                // `texsize.zx*(q3, q3)` (a degenerate case). We honour the
1047                // syntax by keeping only the rightmost expression; the
1048                // earlier ones have no side effects in any preset we've
1049                // surveyed.
1050                while self.eat(&TokenKind::Comma) {
1051                    inner = self.parse_expr()?;
1052                }
1053                let close = self.peek().span;
1054                self.expect(TokenKind::RParen, "`)` to close parenthesised expression")?;
1055                Ok(self.with_span(inner, t.span.merge(close)))
1056            }
1057            TokenKind::LBrace => {
1058                // `{ a, b, c }` initialiser list.
1059                self.bump();
1060                let mut elems = Vec::new();
1061                if !matches!(self.peek().kind, TokenKind::RBrace) {
1062                    loop {
1063                        elems.push(self.parse_expr()?);
1064                        if !self.eat(&TokenKind::Comma) {
1065                            break;
1066                        }
1067                        // Trailing comma OK.
1068                        if matches!(self.peek().kind, TokenKind::RBrace) {
1069                            break;
1070                        }
1071                    }
1072                }
1073                let close = self.peek().span;
1074                self.expect(TokenKind::RBrace, "`}` to close initializer list")?;
1075                Ok(Expr::InitList(InitListExpr {
1076                    elems,
1077                    span: t.span.merge(close),
1078                }))
1079            }
1080            _ => Err(ParseError {
1081                message: format!("expected expression, found {:?}", t.kind),
1082                span: t.span,
1083            }),
1084        }
1085    }
1086
1087    /// Replace the outer span on an expression; used after consuming `(`
1088    /// `expr` `)` so the parenthesised whole gets the wider span.
1089    fn with_span(&self, e: Expr, span: Span) -> Expr {
1090        match e {
1091            Expr::Lit(mut l) => {
1092                l.span = span;
1093                Expr::Lit(l)
1094            }
1095            Expr::Ident(n, _) => Expr::Ident(n, span),
1096            Expr::Binary(mut b) => {
1097                b.span = span;
1098                Expr::Binary(b)
1099            }
1100            Expr::Unary(mut u) => {
1101                u.span = span;
1102                Expr::Unary(u)
1103            }
1104            Expr::Ternary(mut t) => {
1105                t.span = span;
1106                Expr::Ternary(t)
1107            }
1108            Expr::Call(mut c) => {
1109                c.span = span;
1110                Expr::Call(c)
1111            }
1112            Expr::Member(mut m) => {
1113                m.span = span;
1114                Expr::Member(m)
1115            }
1116            Expr::Swizzle(mut s) => {
1117                s.span = span;
1118                Expr::Swizzle(s)
1119            }
1120            Expr::Index(mut i) => {
1121                i.span = span;
1122                Expr::Index(i)
1123            }
1124            Expr::InitList(mut l) => {
1125                l.span = span;
1126                Expr::InitList(l)
1127            }
1128            Expr::Assign(mut a) => {
1129                a.span = span;
1130                Expr::Assign(a)
1131            }
1132        }
1133    }
1134}
1135
1136fn stmt_span(s: &Stmt) -> Span {
1137    match s {
1138        Stmt::LocalDecl(d) => d.span,
1139        Stmt::Assign(a) => a.span,
1140        Stmt::Expr(e) => e.span(),
1141        Stmt::If(i) => i.span,
1142        Stmt::While(w) => w.span,
1143        Stmt::For(f) => f.span,
1144        Stmt::Return(Some(e)) => e.span(),
1145        Stmt::Return(None) => Span::dummy(),
1146        Stmt::Break | Stmt::Continue => Span::dummy(),
1147        Stmt::Block(b) => b.span,
1148    }
1149}
1150
/// Decides whether the text after a `.` is a swizzle rather than a member
/// name. It must be 1–4 bytes long, each one a component letter from `xyzw`
/// or `rgba`.
///
/// HLSL forbids mixing the two letter sets in one swizzle (e.g. `xg`), but
/// this check accepts it; the emitter is stricter than the parser.
fn is_swizzle(s: &str) -> bool {
    const COMPONENTS: &[u8] = b"xyzwrgba";
    (1..=4).contains(&s.len()) && s.bytes().all(|b| COMPONENTS.contains(&b))
}
1162
#[cfg(test)]
mod tests {
    use super::*;

    fn parse(src: &str) -> Result<TranslationUnit, ParseError> {
        parse_hlsl(src)
    }

    /// Parses `src` and returns the statements of its `shader_body`.
    /// Panics if parsing fails or no body is present.
    fn body_stmts(src: &str) -> Vec<Stmt> {
        parse(src).unwrap().shader_body.unwrap().stmts
    }

    /// Returns the initialiser of the local declaration at `stmts[idx]`.
    /// Panics if that statement is not a local decl or has no initialiser.
    fn decl_init(stmts: &[Stmt], idx: usize) -> &Expr {
        let Stmt::LocalDecl(d) = &stmts[idx] else {
            panic!();
        };
        d.init.as_ref().unwrap()
    }

    #[test]
    fn empty_input_is_empty_unit() {
        let tu = parse("").unwrap();
        assert!(tu.items.is_empty());
        assert!(tu.shader_body.is_none());
    }

    #[test]
    fn shader_body_only() {
        let tu = parse("shader_body { float a = 1.0; }").unwrap();
        assert!(tu.items.is_empty());
        let stmts = tu.shader_body.expect("shader_body present").stmts;
        assert_eq!(stmts.len(), 1);
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!("expected local decl, got {:?}", stmts[0]);
        };
        assert_eq!(decl.name, "a");
        assert_eq!(decl.ty.name, "float");
        let init = decl.init.as_ref().unwrap();
        assert!(matches!(
            init,
            Expr::Lit(l) if matches!(l.value, LitValue::Float(v) if (v - 1.0).abs() < 1e-9)
        ));
    }

    #[test]
    fn binary_operator_precedence() {
        // `a + b * c` must group as `a + (b * c)`.
        let stmts = body_stmts("shader_body { float r = a + b * c; }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!("expected decl");
        };
        let Expr::Binary(top) = decl.init.as_ref().unwrap() else {
            panic!("expected binary at top");
        };
        assert_eq!(top.op, BinaryOp::Add);
        // Plain identifier on the left, multiplication on the right.
        assert!(matches!(&*top.lhs, Expr::Ident(n, _) if n == "a"));
        let Expr::Binary(rhs) = &*top.rhs else {
            panic!("expected mul on rhs");
        };
        assert_eq!(rhs.op, BinaryOp::Mul);
    }

    #[test]
    fn ternary_parses_right_associatively() {
        let stmts = body_stmts("shader_body { float r = a ? b : c ? d : e; }");
        let Expr::Ternary(t) = decl_init(&stmts, 0) else {
            panic!("expected ternary");
        };
        // Right-associative: the else branch is itself a ternary.
        assert!(matches!(*t.else_expr, Expr::Ternary(_)));
    }

    #[test]
    fn function_call_parses() {
        let stmts = body_stmts("shader_body { float l = length(uv2); }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!();
        };
        let Some(Expr::Call(call)) = decl.init.as_ref() else {
            panic!("expected call init");
        };
        assert_eq!(call.callee, "length");
        assert_eq!(call.args.len(), 1);
    }

    #[test]
    fn swizzle_vs_member_distinguished() {
        let src = "shader_body { float a = uv.x; float2 b = vec.xy; float c = obj.member; }";
        let stmts = body_stmts(src);
        assert!(matches!(decl_init(&stmts, 0), Expr::Swizzle(_)));
        assert!(matches!(decl_init(&stmts, 1), Expr::Swizzle(_)));
        assert!(matches!(decl_init(&stmts, 2), Expr::Member(_)));
    }

    #[test]
    fn assignment_statement_with_compound_op() {
        let stmts = body_stmts("shader_body { ret += tex2D(sampler_main, uv) * 0.5; }");
        let Stmt::Assign(assign) = &stmts[0] else {
            panic!("expected assignment, got {:?}", stmts[0]);
        };
        assert_eq!(assign.op, AssignOp::Add);
        assert!(matches!(assign.target, Expr::Ident(_, _)));
    }

    #[test]
    fn chained_assign_init_lowered_to_pre_decl_statements() {
        // HLSL `float2 ruv = uv = 0.5;` becomes `uv = 0.5;` then `float2 ruv = uv;`.
        let stmts = body_stmts("shader_body { float2 ruv = uv = 0.5; }");
        assert_eq!(stmts.len(), 2);
        let Stmt::Assign(assign) = &stmts[0] else {
            panic!("expected assign first, got {:?}", stmts[0]);
        };
        assert_eq!(assign.op, AssignOp::Set);
        assert!(matches!(&assign.target, Expr::Ident(n, _) if n == "uv"));
        let Stmt::LocalDecl(decl) = &stmts[1] else {
            panic!("expected decl second, got {:?}", stmts[1]);
        };
        assert_eq!(decl.name, "ruv");
        assert!(matches!(decl.init.as_ref().unwrap(), Expr::Ident(n, _) if n == "uv"));
    }

    #[test]
    fn chained_assign_init_three_deep() {
        // `float2 ruv = uv = uv2 = expr;` yields two assigns, then the decl.
        let stmts = body_stmts("shader_body { float2 ruv = uv = uv2 = 0.5; }");
        assert_eq!(stmts.len(), 3);
        // Innermost assignment comes first: uv2 = 0.5.
        let Stmt::Assign(first) = &stmts[0] else {
            panic!("expected assign, got {:?}", stmts[0]);
        };
        assert!(matches!(&first.target, Expr::Ident(n, _) if n == "uv2"));
        // Then: uv = uv2.
        let Stmt::Assign(second) = &stmts[1] else {
            panic!("expected assign, got {:?}", stmts[1]);
        };
        assert!(matches!(&second.target, Expr::Ident(n, _) if n == "uv"));
        assert!(matches!(&second.value, Expr::Ident(n, _) if n == "uv2"));
        // Finally the declaration, initialised from uv.
        let Stmt::LocalDecl(decl) = &stmts[2] else {
            panic!("expected decl, got {:?}", stmts[2]);
        };
        assert!(matches!(decl.init.as_ref().unwrap(), Expr::Ident(n, _) if n == "uv"));
    }

    #[test]
    fn array_initialiser_local_decl() {
        let stmts = body_stmts("shader_body { float arr[3] = {1.0, 2.0, 3.0}; }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!();
        };
        assert!(decl.array_len.is_some());
        let Some(Expr::InitList(list)) = decl.init.as_ref() else {
            panic!("expected init list");
        };
        assert_eq!(list.elems.len(), 3);
    }

    #[test]
    fn if_else_chain() {
        let src = "shader_body { if (a > 0) b = 1; else if (a < 0) b = -1; else b = 0; }";
        let stmts = body_stmts(src);
        let Stmt::If(top) = &stmts[0] else {
            panic!("expected if, got {:?}", stmts[0]);
        };
        assert!(top.else_branch.is_some());
        // The `else if` must itself be an If statement.
        let Some(else_branch) = &top.else_branch else {
            panic!()
        };
        assert!(matches!(else_branch.as_ref(), Stmt::If(_)));
    }

    #[test]
    fn while_loop() {
        let stmts = body_stmts("shader_body { while (n < 4) { ret += 1.0; n = n + 1; } }");
        assert!(matches!(stmts[0], Stmt::While(_)));
    }

    #[test]
    fn for_loop_with_init_and_step() {
        let src = "shader_body { for (int i = 0; i < 5; i = i + 1) { ret += 1.0; } }";
        let stmts = body_stmts(src);
        let Stmt::For(looped) = &stmts[0] else {
            panic!("expected for, got {:?}", stmts[0]);
        };
        assert!(looped.init.is_some());
        assert!(looped.cond.is_some());
        assert!(looped.step.is_some());
    }

    #[test]
    fn function_definition() {
        let tu = parse("float square(float x) { return x * x; }").unwrap();
        assert_eq!(tu.items.len(), 1);
        let Item::Function(func) = &tu.items[0] else {
            panic!("expected function, got {:?}", tu.items[0]);
        };
        assert_eq!(func.name, "square");
        assert_eq!(func.params.len(), 1);
        assert_eq!(func.params[0].name, "x");
        assert_eq!(func.return_type.name, "float");
    }

    #[test]
    fn sampler_declaration_at_top_level() {
        let tu = parse("sampler sampler_fw_sky; shader_body { ret = 0; }").unwrap();
        assert_eq!(tu.items.len(), 1);
        let Item::SamplerDecl(sampler) = &tu.items[0] else {
            panic!("expected sampler decl");
        };
        assert_eq!(sampler.name, "sampler_fw_sky");
    }

    #[test]
    fn nested_constructor_call() {
        let stmts = body_stmts("shader_body { float2 v = float2(1.0, 2.0); }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!();
        };
        let Some(Expr::Call(call)) = decl.init.as_ref() else {
            panic!("expected call init");
        };
        assert_eq!(call.callee, "float2");
        assert_eq!(call.args.len(), 2);
    }

    #[test]
    fn unary_negate() {
        let stmts = body_stmts("shader_body { float a = -b; }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!();
        };
        let Some(Expr::Unary(unary)) = decl.init.as_ref() else {
            panic!("expected unary, got {:?}", decl.init);
        };
        assert_eq!(unary.op, UnaryOp::Neg);
    }

    #[test]
    fn realistic_comp_shader_body() {
        let src = r#"
            sampler sampler_fw_sky;
            shader_body {
                float2 uvm = frac(uv + time * 0.0);
                ret = tex2D(sampler_main, uvm);
                float diff = 1 - length(ret.xy - GetBlur1(uvm).xy) * 3.5;
                ret.xy *= diff;
            }
        "#;
        let tu = parse(src).unwrap();
        assert_eq!(tu.items.len(), 1);
        let stmts = tu.shader_body.expect("body").stmts;
        assert_eq!(stmts.len(), 4);
        // The final statement is an `*=` into a swizzle target.
        let Stmt::Assign(assign) = &stmts[3] else {
            panic!("expected assign, got {:?}", stmts[3]);
        };
        assert_eq!(assign.op, AssignOp::Mul);
        assert!(matches!(assign.target, Expr::Swizzle(_)));
    }

    #[test]
    fn parse_error_carries_position() {
        // The `;` after the assignment is missing, so parsing fails at `}`.
        let err = parse("shader_body { a = 1 }").unwrap_err();
        assert!(err.message.contains("`;`"));
        // The error points at the offending token rather than byte 0.
        assert!(err.span.start > 0);
    }

    #[test]
    fn realistic_warp_shader_body() {
        // MD2 warp shaders lean on `ret`, `tex2D`, `q1` and local floats.
        let src = r#"
            shader_body {
                float2 uv2 = uv - 0.5;
                uv2 *= aspect.xy;
                float r = length(uv2);
                float ang = atan2(uv2.y, uv2.x);
                uv2 = float2(r * cos(ang), r * sin(ang));
                ret = tex2D(sampler_main, uv2 + 0.5);
            }
        "#;
        assert_eq!(body_stmts(src).len(), 6);
    }

    #[test]
    fn lifted_helper_function_before_body() {
        let src = r#"
            float square(float x) { return x * x; }
            float2 normish(float2 v) { return v * (1.0 / length(v)); }
            shader_body {
                float2 d = normish(uv - 0.5);
                ret = float4(square(d.x), square(d.y), 0, 1);
            }
        "#;
        let tu = parse(src).unwrap();
        assert_eq!(tu.items.len(), 2);
        assert!(tu.items.iter().all(|item| matches!(item, Item::Function(_))));
        assert_eq!(tu.shader_body.unwrap().stmts.len(), 2);
    }

    #[test]
    fn global_multi_name_decl_flattens_to_one_item_per_name() {
        // MD2 idiom: `float3 ret1, neu, blur;` ahead of `shader_body`. Each
        // name becomes its own GlobalVar item for downstream emitters.
        let tu = parse("float3 ret1, neu, blur; shader_body { ret = 0; }").unwrap();
        assert_eq!(tu.items.len(), 3);
        let names: Vec<&str> = tu
            .items
            .iter()
            .map(|item| match item {
                Item::GlobalVar(g) => g.name.as_str(),
                _ => panic!("expected GlobalVar"),
            })
            .collect();
        assert_eq!(names, vec!["ret1", "neu", "blur"]);
    }

    #[test]
    fn local_multi_name_decl_flattens_to_one_stmt_per_name() {
        let stmts = body_stmts("shader_body { float a, b = 2.0, c; }");
        assert_eq!(stmts.len(), 3);
        assert!(stmts.iter().all(|s| matches!(s, Stmt::LocalDecl(_))));
        let Stmt::LocalDecl(b) = &stmts[1] else {
            panic!();
        };
        assert_eq!(b.name, "b");
        assert!(b.init.is_some());
    }

    #[test]
    fn comma_separated_assignments_flatten_into_multiple_stmts() {
        // `a += 1, b += 2;` yields two Stmt::Assign nodes.
        let stmts = body_stmts("shader_body { a += 1, b += 2; }");
        assert_eq!(stmts.len(), 2);
        assert!(stmts.iter().all(|s| matches!(s, Stmt::Assign(_))));
    }

    #[test]
    fn global_const_array_initialiser() {
        let src = r#"
            const float4 samples[5] = { float4(1,0,0,0), float4(0,1,0,0),
                                        float4(0,0,1,0), float4(0,0,0,1),
                                        float4(1,1,1,1) };
            shader_body { ret = samples[0]; }
        "#;
        let tu = parse(src).unwrap();
        assert_eq!(tu.items.len(), 1);
        let Item::GlobalVar(global) = &tu.items[0] else {
            panic!("expected GlobalVar");
        };
        assert!(global.is_const);
        assert_eq!(global.name, "samples");
        assert!(global.array_len.is_some());
        let init = global.init.as_ref().expect("init list present");
        assert!(matches!(init, Expr::InitList(l) if l.elems.len() == 5));
    }

    #[test]
    fn comma_operator_inside_parens_keeps_rightmost() {
        // `(a, b)` evaluates to `b`. A real preset (`orb - inferno`) writes
        // `(q3, q3)`; the parser accepts it and keeps the tail expression.
        let stmts = body_stmts("shader_body { float r = (1.0, 2.0); }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!();
        };
        let Some(Expr::Lit(lit)) = decl.init.as_ref() else {
            panic!("expected float literal init");
        };
        assert!(matches!(lit.value, LitValue::Float(v) if (v - 2.0).abs() < 1e-9));
    }

    #[test]
    fn float_array_initialiser_with_int_literals() {
        // `float arr[5] = {1,2,3,4,5};` was on the residual failure list. The
        // parser handles it; the emitter still needs to catch up.
        let stmts = body_stmts("shader_body { float arr[5] = {1, 2, 3, 4, 5}; }");
        let Stmt::LocalDecl(decl) = &stmts[0] else {
            panic!();
        };
        let Some(Expr::InitList(list)) = decl.init.as_ref() else {
            panic!("expected init list");
        };
        assert_eq!(list.elems.len(), 5);
        // Every element is an int literal; WGSL emit will have to coerce them
        // to float.
        assert!(list.elems.iter().all(|e| matches!(
            e,
            Expr::Lit(Lit {
                value: LitValue::Int(_),
                ..
            })
        )));
    }
}