onedrop_hlsl/
ast.rs

//! HLSL AST.
//!
//! A minimal AST covering the HLSL that MD2 user comp shaders actually use.
//! It is not a faithful HLSL AST: it models only the constructs the regex
//! pipeline already targets, plus the ones behind the remaining long-tail
//! translation failures.
//!
//! A future WGSL emitter will consume this AST and gradually retire the
//! string-driven rewrites in [`crate::translate_shader`]. For now the AST
//! can be built by [`crate::parse`] and is exercised by tests, but no
//! emitter consumes it yet.
11
12use crate::lex::Span;
13
/// A whole HLSL translation unit: top-level items (functions, samplers,
/// file-scope variables such as `static const` declarations) plus an
/// optional `shader_body { ... }` block at the end. MD2 always wraps the
/// main code in `shader_body`; standalone fragments (test fixtures,
/// function-only inputs) omit it.
#[derive(Debug, Clone, PartialEq)]
pub struct TranslationUnit {
    /// Top-level items. Multi-name global declarations arrive already
    /// flattened into one [`Item::GlobalVar`] per name.
    pub items: Vec<Item>,
    /// `shader_body { ... }` — its statements, or `None` if absent.
    pub shader_body: Option<Block>,
    /// Source span of the translation unit.
    pub span: Span,
}
25
/// Top-level item in the translation unit.
#[derive(Debug, Clone, PartialEq)]
pub enum Item {
    /// `<type> <name>(<params>) { <body> }`
    Function(FunctionDef),
    /// `sampler <name>;`, `sampler2D <name>;`, `sampler3D <name>;`, or
    /// `sampler <name> = sampler_state { ... };`. The contents of a
    /// `sampler_state` block are not stored in the AST: [`SamplerDecl`]
    /// holds only the tag, name and span (MD2 never inspects the state
    /// block).
    SamplerDecl(SamplerDecl),
    /// `[static] [const] <type> <name>[<array>] [= <expr>];` — file-scope
    /// variable. The qualifier flags carry whether the original source
    /// said `static`/`const` so emitters can preserve immutability hints
    /// when round-tripping. The regex pipeline collapses everything into
    /// `var NAME: TYPE` at WGSL emit time today; the AST is richer so a
    /// future emitter has more to work with.
    ///
    /// Multi-name declarations (`float3 a, b, c;`) flatten at parse time
    /// into multiple `GlobalVar` items, one per name.
    GlobalVar(GlobalVar),
}
46
/// Function definition: `<type> <name>(<params>) { <body> }`.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionDef {
    /// Declared return type (may be `void`).
    pub return_type: TypeRef,
    /// Function name as written in the source.
    pub name: String,
    /// Formal parameters.
    pub params: Vec<Param>,
    /// Function body.
    pub body: Block,
    /// Source span of the whole definition.
    pub span: Span,
}
56
/// One function parameter: optional direction qualifier, type, name.
#[derive(Debug, Clone, PartialEq)]
pub struct Param {
    /// `in`/`out`/`inout`, or `None` when the source gave no qualifier.
    pub qualifier: Option<ParamQualifier>,
    /// Declared parameter type.
    pub ty: TypeRef,
    /// Parameter name.
    pub name: String,
    /// Source span of the parameter.
    pub span: Span,
}
65
/// HLSL `in`, `out`, `inout` parameter direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParamQualifier {
    /// `in`
    In,
    /// `out`
    Out,
    /// `inout`
    InOut,
}
73
/// Sampler declaration: `sampler <name>;`, `sampler2D <name>;` or
/// `sampler3D <name>;`. MD2 user shaders use `sampler sampler_<name>` for
/// disk-loaded textures and `sampler_main` / `sampler_blur1..3` for
/// built-ins.
///
/// Only the introducing keyword and the name are recorded; there is no
/// field for a `= sampler_state { ... }` initialiser.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplerDecl {
    /// Tag at the start of the decl: `sampler`, `sampler2D`, or `sampler3D`.
    pub tag: SamplerTag,
    /// Declared sampler name, e.g. `sampler_main`.
    pub name: String,
    /// Source span of the declaration.
    pub span: Span,
}
84
/// Keyword that introduced a [`SamplerDecl`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplerTag {
    /// `sampler`
    Sampler,
    /// `sampler2D`
    Sampler2D,
    /// `sampler3D`
    Sampler3D,
}
91
/// File-scope variable declaration — the payload of [`Item::GlobalVar`].
/// Each node holds exactly one name; multi-name declarations are split
/// into several nodes.
#[derive(Debug, Clone, PartialEq)]
pub struct GlobalVar {
    /// `true` if the source said `static`.
    pub is_static: bool,
    /// `true` if the source said `const`.
    pub is_const: bool,
    /// Declared type (for arrays, the type written before the name).
    pub ty: TypeRef,
    /// Variable name.
    pub name: String,
    /// Set when the declaration carried an `[N]` array suffix. Held as an
    /// expression, so the length is not necessarily a literal.
    pub array_len: Option<Expr>,
    /// Set when an initializer was provided. `static const` decls keep
    /// the init mandatory by HLSL rules, but we don't enforce that
    /// here — the parser accepts whatever the source says, and a
    /// future validation pass can flag missing inits if needed.
    pub init: Option<Expr>,
    /// Source span of the declaration.
    pub span: Span,
}
109
110/// Reference to a (possibly compound) HLSL type — `float`, `float2`,
111/// `float2x2`, `int`, `bool`, `void`, or a struct name. Kept as a raw
112/// identifier plus span so the emitter can map to WGSL on its own table.
113#[derive(Debug, Clone, PartialEq)]
114pub struct TypeRef {
115    pub name: String,
116    pub span: Span,
117}
118
119impl TypeRef {
120    pub fn new(name: impl Into<String>, span: Span) -> Self {
121        Self {
122            name: name.into(),
123            span,
124        }
125    }
126}
127
/// `{ stmt; stmt; ... }` — a sequence of statements with its own scope.
/// Used for function bodies, `shader_body`, and nested `{ ... }`
/// statements ([`Stmt::Block`]).
#[derive(Debug, Clone, PartialEq)]
pub struct Block {
    /// Statements in the block.
    pub stmts: Vec<Stmt>,
    /// Source span of the block, braces included or not as the parser
    /// decides — TODO confirm.
    pub span: Span,
}
134
/// One statement. Statements that contain expressions own the `Expr`
/// directly — no `Rc`/`Box` indirection at the top level. Nested
/// statements (branch and loop bodies) are boxed inside [`IfStmt`],
/// [`WhileStmt`] and [`ForStmt`].
#[derive(Debug, Clone, PartialEq)]
pub enum Stmt {
    /// `<type> <name> [= <init>];` or `<type> <name>[<n>];`
    LocalDecl(LocalDecl),
    /// `<lhs> [op]= <rhs>;` with the same `=`/`+=`/`-=`/`*=`/`/=`/`%=` set
    /// the lexer recognises (see [`AssignOp`]).
    Assign(AssignStmt),
    /// `<expr>;` — statement form of an expression (typically a call).
    Expr(Expr),
    /// `if (<cond>) <then> [else <else>]`
    If(IfStmt),
    /// `while (<cond>) <body>` and `do <body> while (<cond>);` — the two
    /// forms are told apart by [`WhileStmt::do_while`].
    While(WhileStmt),
    /// `for (<init>; <cond>; <step>) <body>`
    For(ForStmt),
    /// `return [<expr>];` — `None` for a bare `return;`.
    Return(Option<Expr>),
    /// `break;`
    Break,
    /// `continue;`
    Continue,
    /// `{ ... }` — nested block.
    Block(Block),
}
161
/// Local variable declaration: `<type> <name> [= <init>];` or
/// `<type> <name>[<n>] [= { ... }];`.
///
/// NOTE(review): unlike [`GlobalVar`] there are no `static`/`const` flags
/// here, so such qualifiers on a local are not represented in the AST.
#[derive(Debug, Clone, PartialEq)]
pub struct LocalDecl {
    /// Declared type.
    pub ty: TypeRef,
    /// Variable name.
    pub name: String,
    /// `float arr[4] = {1,2,3,4};` — present when the declaration uses
    /// `[<N>]` array suffix. `None` for scalar/vector decls.
    pub array_len: Option<Expr>,
    /// Initialiser, if any. For arrays this is typically an
    /// [`Expr::InitList`].
    pub init: Option<Expr>,
    /// Source span of the declaration.
    pub span: Span,
}
172
/// Statement-level assignment: `<target> <op> <value>;`. Assignments used
/// inside expressions (e.g. `for` headers) are [`AssignExpr`] instead.
#[derive(Debug, Clone, PartialEq)]
pub struct AssignStmt {
    /// Left-hand side. Typed as a general [`Expr`]; nothing here checks
    /// that it is assignable.
    pub target: Expr,
    /// `=` or one of the compound operators.
    pub op: AssignOp,
    /// Right-hand side.
    pub value: Expr,
    /// Source span of the whole statement.
    pub span: Span,
}
180
/// Assignment operator shared by [`AssignStmt`] and [`AssignExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AssignOp {
    /// `=`
    Set,
    /// `+=`
    Add,
    /// `-=`
    Sub,
    /// `*=`
    Mul,
    /// `/=`
    Div,
    /// `%=`
    Rem,
}
190
/// `if (<cond>) <then> [else <else>]`.
#[derive(Debug, Clone, PartialEq)]
pub struct IfStmt {
    /// Condition expression.
    pub cond: Expr,
    /// Statement run when `cond` holds (often a [`Stmt::Block`]).
    pub then_branch: Box<Stmt>,
    /// `else` arm, if present. An `else if` presumably appears as a nested
    /// [`Stmt::If`] here — confirm against the parser.
    pub else_branch: Option<Box<Stmt>>,
    /// Source span of the whole statement.
    pub span: Span,
}
198
/// `while (<cond>) <body>` or `do <body> while (<cond>);`.
#[derive(Debug, Clone, PartialEq)]
pub struct WhileStmt {
    /// Loop condition.
    pub cond: Expr,
    /// Loop body.
    pub body: Box<Stmt>,
    /// `true` for `do { } while (...);` — the body runs at least once.
    /// `false` for a plain `while`, which tests `cond` before each pass.
    pub do_while: bool,
    /// Source span of the whole loop.
    pub span: Span,
}
207
/// `for (<init>; <cond>; <step>) <body>` — every header clause is optional.
#[derive(Debug, Clone, PartialEq)]
pub struct ForStmt {
    /// Header initialiser, held as a statement (e.g. a [`Stmt::LocalDecl`]
    /// for `int i = 0`).
    pub init: Option<Box<Stmt>>,
    /// Loop condition; `None` when the clause is empty.
    pub cond: Option<Expr>,
    /// Step expression, e.g. an [`Expr::Assign`] such as `i += 1`.
    ///
    /// NOTE(review): neither [`UnaryOp`] nor [`AssignOp`] has `++`/`--`,
    /// so an `i++` step must be lowered or rejected by the parser —
    /// confirm which.
    pub step: Option<Expr>,
    /// Loop body.
    pub body: Box<Stmt>,
    /// Source span of the whole loop.
    pub span: Span,
}
216
/// HLSL expression node. Members and swizzles are split because WGSL
/// treats them slightly differently — keeping them separate at the AST
/// level lets the emitter choose the right WGSL form.
///
/// Every variant carries its own source span; [`Expr::span`] returns it.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// `42`, `1.5`, `true`
    Lit(Lit),
    /// `foo`, `aspect`, `texsize` — the identifier text and its span.
    Ident(String, Span),
    /// `lhs + rhs`, `lhs * rhs`, `lhs && rhs`, …
    Binary(BinaryExpr),
    /// `-x`, `!x`, `+x`, `~x`
    Unary(UnaryExpr),
    /// `cond ? then : else_`
    Ternary(TernaryExpr),
    /// `f(arg, arg, …)` — also covers vector/matrix constructors like
    /// `float2(x, y)`, since they share syntax.
    Call(CallExpr),
    /// `obj.member` — struct field access. Swizzles use [`Expr::Swizzle`].
    Member(MemberExpr),
    /// `vec.xy`, `color.rgba`, `m.r` — sub-vector / component access.
    /// HLSL spells RGBA swizzles equivalently to XYZW.
    Swizzle(SwizzleExpr),
    /// `arr[i]`
    Index(IndexExpr),
    /// `{ 1, 2, 3 }` — initialiser list used in `float arr[3] = {…};`
    /// and matrix initialisers. Kept distinct from `Call` so the emitter
    /// can target WGSL's `array<f32, 3>(…)` form.
    InitList(InitListExpr),
    /// Assignment as an expression — `i = i + 1`, `x += 2`. HLSL allows
    /// these inside `for` headers and other expression contexts. The
    /// statement form (`<lhs> = <rhs>;`) is modelled separately as
    /// [`AssignStmt`] so the common top-level case is cheap to walk.
    Assign(AssignExpr),
}
252
/// Assignment used as an expression (see [`Expr::Assign`]). It uses the
/// same [`AssignOp`] set as the statement form [`AssignStmt`].
#[derive(Debug, Clone, PartialEq)]
pub struct AssignExpr {
    /// Left-hand side (not checked for assignability here).
    pub target: Box<Expr>,
    /// `=` or one of the compound operators.
    pub op: AssignOp,
    /// Right-hand side.
    pub value: Box<Expr>,
    /// Source span of the whole assignment.
    pub span: Span,
}
260
261impl Expr {
262    pub fn span(&self) -> Span {
263        match self {
264            Expr::Lit(l) => l.span,
265            Expr::Ident(_, s) => *s,
266            Expr::Binary(b) => b.span,
267            Expr::Unary(u) => u.span,
268            Expr::Ternary(t) => t.span,
269            Expr::Call(c) => c.span,
270            Expr::Member(m) => m.span,
271            Expr::Swizzle(s) => s.span,
272            Expr::Index(i) => i.span,
273            Expr::InitList(l) => l.span,
274            Expr::Assign(a) => a.span,
275        }
276    }
277}
278
/// Literal value together with its source span.
#[derive(Debug, Clone, PartialEq)]
pub struct Lit {
    /// The literal's value.
    pub value: LitValue,
    /// Source span of the literal token.
    pub span: Span,
}
284
/// Payload of a [`Lit`]. Derives only `PartialEq` (no `Eq`) because of
/// the `f64` variant.
#[derive(Debug, Clone, PartialEq)]
pub enum LitValue {
    /// Integer literal, e.g. `42`.
    Int(i64),
    /// Floating-point literal, e.g. `1.5`.
    Float(f64),
    /// `true` / `false`.
    Bool(bool),
}
291
/// Binary operation `<lhs> <op> <rhs>`.
#[derive(Debug, Clone, PartialEq)]
pub struct BinaryExpr {
    /// The operator.
    pub op: BinaryOp,
    /// Left operand.
    pub lhs: Box<Expr>,
    /// Right operand.
    pub rhs: Box<Expr>,
    /// Source span of the whole operation.
    pub span: Span,
}
299
/// Binary operator of a [`BinaryExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Rem,
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
    /// `&&` (logical and)
    And,
    /// `||` (logical or)
    Or,
    /// `&`
    BitAnd,
    /// `|`
    BitOr,
    /// `^`
    BitXor,
    /// `<<`
    Shl,
    /// `>>`
    Shr,
}
321
/// Prefix unary operation `<op><operand>`.
#[derive(Debug, Clone, PartialEq)]
pub struct UnaryExpr {
    /// The operator.
    pub op: UnaryOp,
    /// The operand.
    pub operand: Box<Expr>,
    /// Source span of the whole operation.
    pub span: Span,
}
328
/// Prefix operator of a [`UnaryExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// `-x`
    Neg,
    /// `+x` — a no-op kept for source fidelity; an emitter may drop it.
    Pos,
    /// `!x`
    Not,
    /// `~x`
    BitNot,
}
336
/// Conditional expression `<cond> ? <then_expr> : <else_expr>`.
#[derive(Debug, Clone, PartialEq)]
pub struct TernaryExpr {
    /// Condition.
    pub cond: Box<Expr>,
    /// Value when `cond` holds.
    pub then_expr: Box<Expr>,
    /// Value otherwise.
    pub else_expr: Box<Expr>,
    /// Source span of the whole expression.
    pub span: Span,
}
344
/// Call or constructor `<callee>(<args>)`. `callee` is a plain name, so
/// only calls through an identifier are representable: functions,
/// intrinsics and type constructors such as `float2`.
#[derive(Debug, Clone, PartialEq)]
pub struct CallExpr {
    /// Name of the function or type being called.
    pub callee: String,
    /// Arguments.
    pub args: Vec<Expr>,
    /// Source span of the whole call.
    pub span: Span,
}
351
/// Struct field access `<base>.<member>`. Swizzles are [`SwizzleExpr`].
#[derive(Debug, Clone, PartialEq)]
pub struct MemberExpr {
    /// Expression whose field is accessed.
    pub base: Box<Expr>,
    /// Field name.
    pub member: String,
    /// Source span of the whole access.
    pub span: Span,
}
358
/// Vector component access / swizzle `<base>.<components>`.
#[derive(Debug, Clone, PartialEq)]
pub struct SwizzleExpr {
    /// Vector (or scalar) being swizzled.
    pub base: Box<Expr>,
    /// Original lexeme: `xy`, `rgba`, `zyx`, etc. Length 1–4. Caller is
    /// responsible for normalising rgba→xyzw at emit time.
    pub components: String,
    /// Source span of the whole swizzle.
    pub span: Span,
}
367
/// Indexing `<base>[<index>]`.
#[derive(Debug, Clone, PartialEq)]
pub struct IndexExpr {
    /// Expression being indexed.
    pub base: Box<Expr>,
    /// Index expression.
    pub index: Box<Expr>,
    /// Source span of the whole indexing expression.
    pub span: Span,
}
374
/// Brace initialiser list `{ <elems> }` (see [`Expr::InitList`]).
#[derive(Debug, Clone, PartialEq)]
pub struct InitListExpr {
    /// Elements of the list.
    pub elems: Vec<Expr>,
    /// Source span of the whole list.
    pub span: Span,
}
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn expr_span_dispatch() {
        let s = Span {
            start: 5,
            end: 10,
            line: 1,
            col: 6,
        };
        let e = Expr::Ident("foo".to_string(), s);
        assert_eq!(e.span(), s);
    }

    /// `Expr::span` must return the outermost node's own span for every
    /// variant. All children share the span `0..1`, so an arm that leaked
    /// a child's span would fail.
    #[test]
    fn expr_span_covers_every_variant() {
        let mk = |start, end| Span {
            start,
            end,
            line: 1,
            col: 1,
        };
        let leaf = || Expr::Ident("x".to_string(), mk(0, 1));

        let cases = [
            (
                Expr::Lit(Lit {
                    value: LitValue::Float(1.5),
                    span: mk(10, 13),
                }),
                mk(10, 13),
            ),
            (Expr::Ident("y".to_string(), mk(15, 16)), mk(15, 16)),
            (
                Expr::Binary(BinaryExpr {
                    op: BinaryOp::Mul,
                    lhs: Box::new(leaf()),
                    rhs: Box::new(leaf()),
                    span: mk(20, 25),
                }),
                mk(20, 25),
            ),
            (
                Expr::Unary(UnaryExpr {
                    op: UnaryOp::Neg,
                    operand: Box::new(leaf()),
                    span: mk(30, 32),
                }),
                mk(30, 32),
            ),
            (
                Expr::Ternary(TernaryExpr {
                    cond: Box::new(leaf()),
                    then_expr: Box::new(leaf()),
                    else_expr: Box::new(leaf()),
                    span: mk(40, 49),
                }),
                mk(40, 49),
            ),
            (
                Expr::Call(CallExpr {
                    callee: "float2".to_string(),
                    args: vec![leaf(), leaf()],
                    span: mk(50, 62),
                }),
                mk(50, 62),
            ),
            (
                Expr::Member(MemberExpr {
                    base: Box::new(leaf()),
                    member: "uv".to_string(),
                    span: mk(70, 74),
                }),
                mk(70, 74),
            ),
            (
                Expr::Swizzle(SwizzleExpr {
                    base: Box::new(leaf()),
                    components: "xy".to_string(),
                    span: mk(80, 84),
                }),
                mk(80, 84),
            ),
            (
                Expr::Index(IndexExpr {
                    base: Box::new(leaf()),
                    index: Box::new(leaf()),
                    span: mk(90, 94),
                }),
                mk(90, 94),
            ),
            (
                Expr::InitList(InitListExpr {
                    elems: vec![leaf(), leaf(), leaf()],
                    span: mk(100, 109),
                }),
                mk(100, 109),
            ),
            (
                Expr::Assign(AssignExpr {
                    target: Box::new(leaf()),
                    op: AssignOp::Add,
                    value: Box::new(leaf()),
                    span: mk(110, 116),
                }),
                mk(110, 116),
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.span(), expected, "wrong span for {:?}", expr);
        }
    }

    #[test]
    fn lit_value_variants_hold_payload() {
        let i = LitValue::Int(42);
        let f = LitValue::Float(1.5);
        let b = LitValue::Bool(true);
        assert!(matches!(i, LitValue::Int(42)));
        assert!(matches!(f, LitValue::Float(v) if (v - 1.5).abs() < 1e-9));
        assert!(matches!(b, LitValue::Bool(true)));
        // Different variants never compare equal, even for numerically
        // equal payloads.
        assert_ne!(LitValue::Int(1), LitValue::Float(1.0));
    }

    #[test]
    fn translation_unit_can_be_empty() {
        let tu = TranslationUnit {
            items: vec![],
            shader_body: None,
            span: Span::dummy(),
        };
        assert!(tu.items.is_empty());
        assert!(tu.shader_body.is_none());
    }
}