// onedrop_renderer/blur_pipeline.rs

//! Blur-pyramid GPU pipeline.
//!
//! Builds the 3-level cumulative Gaussian blur (Blur1 → Blur2 → Blur3) that
//! MD2 user comp shaders sample via `GetBlur1/2/3`. Each level is the
//! previous level plus one more separable Gaussian pass. Gaussian variances
//! add from level to level, and each level also runs at a coarser
//! resolution than the last, so the effective blur radius (measured in
//! full-resolution pixels) grows quickly: Blur1 is mildly soft, Blur3 is
//! markedly so.
//!
//! ## Pipeline shape
//!
//! Each blur level renders at half the previous level's resolution:
//! Blur1 = ½ render, Blur2 = ¼, Blur3 = ⅛. Per level we do two
//! fragment-shader passes (separable Gaussian):
//!
//! 1. **Horizontal**: read from the previous-level texture (Blur0 =
//!    `render_texture` at full res) and write to that level's
//!    `blurN_scratch_texture`, which is already at the downsampled
//!    resolution. The render-pass viewport defaults to the destination
//!    texture's size, so the larger source is read through the bilinear
//!    sampler — an implicit downsample.
//! 2. **Vertical**: read from `blurN_scratch_texture` and write to that
//!    level's `blurN_texture` (same resolution).
//!
//! After the third level, the three blur textures hold progressively
//! softer copies of the warp-pass output, ready to be bound into the
//! comp pass's bind group at bindings 3/4/5. User shaders call
//! `GetBlur1/2/3(uv)` with the same normalised UV the comp pass already
//! uses; the smaller textures are transparent to the caller because the
//! bilinear sampler upscales at sample time.
//!
//! The shader itself uses a fixed 9-tap kernel so the WGSL compiler can
//! unroll it. The per-pass `texel_size` uniform tracks the source
//! texture's resolution, so each tap steps one source texel regardless of
//! which level is being rendered.

35use bytemuck::{Pod, Zeroable};
36use wgpu::util::DeviceExt;
37
/// Gaussian σ for each blur level's H and V passes, measured in texels of
/// the texture that pass reads from.
///
/// Each level blurs the previous level's output, so variances roughly add
/// (σ_total² ≈ Σ σᵢ², once every σᵢ is expressed in the same unit). Each
/// level's source is also half the resolution of the previous level's
/// source, so a given σ here spans twice as many full-resolution pixels as
/// it would one level up. The values below give visually distinct levels:
/// Blur1 just-soft, Blur2 noticeably hazy, Blur3 strong halos. This is
/// close to what MD2 presets expect.
const SIGMA_BLUR1: f32 = 2.0;
const SIGMA_BLUR2: f32 = 3.0;
const SIGMA_BLUR3: f32 = 4.0;

47#[repr(C)]
48#[derive(Debug, Clone, Copy, Pod, Zeroable)]
49struct BlurUniforms {
50    direction: [f32; 2],
51    texel_size: [f32; 2],
52    sigma: f32,
53    _pad0: f32,
54    _pad1: f32,
55    _pad2: f32,
56}
57
58/// Drives the 6 blur passes (3 levels × H+V). One pipeline, six bind
59/// groups — each pass needs its own bind group because the source
60/// texture changes between passes.
61pub struct BlurPipeline {
62    pipeline: wgpu::RenderPipeline,
63    bind_group_layout: wgpu::BindGroupLayout,
64    sampler: wgpu::Sampler,
65    /// One uniform buffer per pass — 6 in total. Cheap; uniform buffers
66    /// are tiny and writing different sigmas to one buffer mid-frame
67    /// would require either dynamic offsets or per-pass write_buffer
68    /// calls with a barrier, neither worth the complexity here.
69    uniform_buffers: [wgpu::Buffer; 6],
70    /// One bind group per pass, holding (uniforms[i], source texture,
71    /// sampler). Recreated when textures rebind (resize).
72    bind_groups: [wgpu::BindGroup; 6],
73}
74
75impl BlurPipeline {
76    /// Build the pipeline against `target_format` and wire bind groups
77    /// for the 6 blur passes. The pyramid is downsampled at each level
78    /// (Blur1 = ½ render, Blur2 = ¼, Blur3 = ⅛); per-level scratches
79    /// match their level's resolution.
80    ///
81    /// The pass order is fixed at construction:
82    /// - pass 0 (H, σ=σ1): src = `render_view`,        dst = `b1_scratch`
83    /// - pass 1 (V, σ=σ1): src = `b1_scratch`,         dst = `blur1_view`
84    /// - pass 2 (H, σ=σ2): src = `blur1_view`,         dst = `b2_scratch`
85    /// - pass 3 (V, σ=σ2): src = `b2_scratch`,         dst = `blur2_view`
86    /// - pass 4 (H, σ=σ3): src = `blur2_view`,         dst = `b3_scratch`
87    /// - pass 5 (V, σ=σ3): src = `b3_scratch`,         dst = `blur3_view`
88    ///
89    /// `texel_size` is set per-pass to `1 / source-resolution`, so each
90    /// Gaussian tap walks one source-texel in UV space regardless of
91    /// which level we're at.
92    #[allow(clippy::too_many_arguments)]
93    pub fn new(
94        device: &wgpu::Device,
95        target_format: wgpu::TextureFormat,
96        width: u32,
97        height: u32,
98        render_view: &wgpu::TextureView,
99        blur1_view: &wgpu::TextureView,
100        blur2_view: &wgpu::TextureView,
101        blur3_view: &wgpu::TextureView,
102        blur1_scratch_view: &wgpu::TextureView,
103        blur2_scratch_view: &wgpu::TextureView,
104        blur3_scratch_view: &wgpu::TextureView,
105    ) -> Self {
106        // `blur3_view` is referenced only as a render target (see `render`),
107        // not as a source, but we still take it in `new` so the caller's
108        // ownership model stays symmetric across the four blur textures.
109        let _ = blur3_view;
110        // Per-source-resolution texel sizes. Pass i reads from
111        // `pass_source_texels[i]`, so the shader's `texel_size`
112        // matches that source's stride in normalised UV.
113        let (w1, h1) = (width.max(2) / 2, height.max(2) / 2);
114        let (w2, h2) = (width.max(4) / 4, height.max(4) / 4);
115        let (w3, h3) = (width.max(8) / 8, height.max(8) / 8);
116        let texel_full = [1.0 / width as f32, 1.0 / height as f32];
117        let texel_b1 = [1.0 / w1 as f32, 1.0 / h1 as f32];
118        let texel_b2 = [1.0 / w2 as f32, 1.0 / h2 as f32];
119        let texel_b3 = [1.0 / w3 as f32, 1.0 / h3 as f32];
120
121        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
122            label: Some("Blur BGL"),
123            entries: &[
124                wgpu::BindGroupLayoutEntry {
125                    binding: 0,
126                    visibility: wgpu::ShaderStages::FRAGMENT,
127                    ty: wgpu::BindingType::Buffer {
128                        ty: wgpu::BufferBindingType::Uniform,
129                        has_dynamic_offset: false,
130                        min_binding_size: None,
131                    },
132                    count: None,
133                },
134                wgpu::BindGroupLayoutEntry {
135                    binding: 1,
136                    visibility: wgpu::ShaderStages::FRAGMENT,
137                    ty: wgpu::BindingType::Texture {
138                        sample_type: wgpu::TextureSampleType::Float { filterable: true },
139                        view_dimension: wgpu::TextureViewDimension::D2,
140                        multisampled: false,
141                    },
142                    count: None,
143                },
144                wgpu::BindGroupLayoutEntry {
145                    binding: 2,
146                    visibility: wgpu::ShaderStages::FRAGMENT,
147                    ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
148                    count: None,
149                },
150            ],
151        });
152
153        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
154            label: Some("Blur Pipeline Layout"),
155            bind_group_layouts: &[Some(&bind_group_layout)],
156            immediate_size: 0,
157        });
158
159        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
160            label: Some("Blur Shader"),
161            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/blur.wgsl").into()),
162        });
163
164        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
165            label: Some("Blur Pipeline"),
166            layout: Some(&pipeline_layout),
167            vertex: wgpu::VertexState {
168                module: &shader,
169                entry_point: Some("vs_main"),
170                buffers: &[],
171                compilation_options: Default::default(),
172            },
173            fragment: Some(wgpu::FragmentState {
174                module: &shader,
175                entry_point: Some("fs_main"),
176                targets: &[Some(wgpu::ColorTargetState {
177                    format: target_format,
178                    blend: Some(wgpu::BlendState::REPLACE),
179                    write_mask: wgpu::ColorWrites::ALL,
180                })],
181                compilation_options: Default::default(),
182            }),
183            primitive: wgpu::PrimitiveState {
184                topology: wgpu::PrimitiveTopology::TriangleList,
185                ..Default::default()
186            },
187            depth_stencil: None,
188            multisample: wgpu::MultisampleState::default(),
189            multiview_mask: None,
190            cache: None,
191        });
192
193        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
194            label: Some("Blur Sampler"),
195            address_mode_u: wgpu::AddressMode::ClampToEdge,
196            address_mode_v: wgpu::AddressMode::ClampToEdge,
197            address_mode_w: wgpu::AddressMode::ClampToEdge,
198            mag_filter: wgpu::FilterMode::Linear,
199            min_filter: wgpu::FilterMode::Linear,
200            mipmap_filter: wgpu::MipmapFilterMode::Linear,
201            ..Default::default()
202        });
203
204        // One uniform buffer per pass. Direction alternates H/V; sigma
205        // doubles roughly per blur level. `texel_size` matches the
206        // source texture's resolution at each pass — see the H/V
207        // dispatch order above.
208        let make_uniforms = |dir: [f32; 2], texel: [f32; 2], sigma: f32| BlurUniforms {
209            direction: dir,
210            texel_size: texel,
211            sigma,
212            _pad0: 0.0,
213            _pad1: 0.0,
214            _pad2: 0.0,
215        };
216        let pass_specs: [BlurUniforms; 6] = [
217            // pass 0: H, src = render (full)        → b1_scratch
218            make_uniforms([1.0, 0.0], texel_full, SIGMA_BLUR1),
219            // pass 1: V, src = b1_scratch (½)       → blur1
220            make_uniforms([0.0, 1.0], texel_b1, SIGMA_BLUR1),
221            // pass 2: H, src = blur1 (½)            → b2_scratch
222            make_uniforms([1.0, 0.0], texel_b1, SIGMA_BLUR2),
223            // pass 3: V, src = b2_scratch (¼)       → blur2
224            make_uniforms([0.0, 1.0], texel_b2, SIGMA_BLUR2),
225            // pass 4: H, src = blur2 (¼)            → b3_scratch
226            make_uniforms([1.0, 0.0], texel_b2, SIGMA_BLUR3),
227            // pass 5: V, src = b3_scratch (⅛)       → blur3
228            make_uniforms([0.0, 1.0], texel_b3, SIGMA_BLUR3),
229        ];
230
231        let uniform_buffers: [wgpu::Buffer; 6] = std::array::from_fn(|i| {
232            device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
233                label: Some(&format!("Blur Uniforms #{i}")),
234                contents: bytemuck::bytes_of(&pass_specs[i]),
235                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
236            })
237        });
238
239        let pass_sources: [&wgpu::TextureView; 6] = [
240            render_view,        // pass 0: H from render (full)
241            blur1_scratch_view, // pass 1: V from blur1_scratch (½)
242            blur1_view,         // pass 2: H from Blur1 (½)
243            blur2_scratch_view, // pass 3: V from blur2_scratch (¼)
244            blur2_view,         // pass 4: H from Blur2 (¼)
245            blur3_scratch_view, // pass 5: V from blur3_scratch (⅛)
246        ];
247
248        let bind_groups: [wgpu::BindGroup; 6] = std::array::from_fn(|i| {
249            Self::make_bind_group(
250                device,
251                &bind_group_layout,
252                &uniform_buffers[i],
253                pass_sources[i],
254                &sampler,
255                i,
256            )
257        });
258
259        Self {
260            pipeline,
261            bind_group_layout,
262            sampler,
263            uniform_buffers,
264            bind_groups,
265        }
266    }
267
268    fn make_bind_group(
269        device: &wgpu::Device,
270        layout: &wgpu::BindGroupLayout,
271        uniforms: &wgpu::Buffer,
272        source: &wgpu::TextureView,
273        sampler: &wgpu::Sampler,
274        idx: usize,
275    ) -> wgpu::BindGroup {
276        device.create_bind_group(&wgpu::BindGroupDescriptor {
277            label: Some(&format!("Blur Bind Group #{idx}")),
278            layout,
279            entries: &[
280                wgpu::BindGroupEntry {
281                    binding: 0,
282                    resource: uniforms.as_entire_binding(),
283                },
284                wgpu::BindGroupEntry {
285                    binding: 1,
286                    resource: wgpu::BindingResource::TextureView(source),
287                },
288                wgpu::BindGroupEntry {
289                    binding: 2,
290                    resource: wgpu::BindingResource::Sampler(sampler),
291                },
292            ],
293        })
294    }
295
296    /// Recreate the bind groups + refresh per-pass uniform `texel_size`
297    /// after the underlying textures are invalidated (resize).
298    #[allow(clippy::too_many_arguments)]
299    pub fn rebind(
300        &mut self,
301        device: &wgpu::Device,
302        queue: &wgpu::Queue,
303        width: u32,
304        height: u32,
305        render_view: &wgpu::TextureView,
306        blur1_view: &wgpu::TextureView,
307        blur2_view: &wgpu::TextureView,
308        blur1_scratch_view: &wgpu::TextureView,
309        blur2_scratch_view: &wgpu::TextureView,
310        blur3_scratch_view: &wgpu::TextureView,
311    ) {
312        let (w1, h1) = (width.max(2) / 2, height.max(2) / 2);
313        let (w2, h2) = (width.max(4) / 4, height.max(4) / 4);
314        let (w3, h3) = (width.max(8) / 8, height.max(8) / 8);
315        let texel_full = [1.0 / width as f32, 1.0 / height as f32];
316        let texel_b1 = [1.0 / w1 as f32, 1.0 / h1 as f32];
317        let texel_b2 = [1.0 / w2 as f32, 1.0 / h2 as f32];
318        let texel_b3 = [1.0 / w3 as f32, 1.0 / h3 as f32];
319        // Per-pass (sigma, source-texel, direction) tuples — must
320        // match the order used in `new`.
321        let pass_uniforms: [(f32, [f32; 2], [f32; 2]); 6] = [
322            (SIGMA_BLUR1, texel_full, [1.0, 0.0]),
323            (SIGMA_BLUR1, texel_b1, [0.0, 1.0]),
324            (SIGMA_BLUR2, texel_b1, [1.0, 0.0]),
325            (SIGMA_BLUR2, texel_b2, [0.0, 1.0]),
326            (SIGMA_BLUR3, texel_b2, [1.0, 0.0]),
327            (SIGMA_BLUR3, texel_b3, [0.0, 1.0]),
328        ];
329        for (i, &(sigma, texel, direction)) in pass_uniforms.iter().enumerate() {
330            let u = BlurUniforms {
331                direction,
332                texel_size: texel,
333                sigma,
334                _pad0: 0.0,
335                _pad1: 0.0,
336                _pad2: 0.0,
337            };
338            queue.write_buffer(&self.uniform_buffers[i], 0, bytemuck::bytes_of(&u));
339        }
340
341        let pass_sources: [&wgpu::TextureView; 6] = [
342            render_view,
343            blur1_scratch_view,
344            blur1_view,
345            blur2_scratch_view,
346            blur2_view,
347            blur3_scratch_view,
348        ];
349        for (i, source) in pass_sources.iter().enumerate() {
350            self.bind_groups[i] = Self::make_bind_group(
351                device,
352                &self.bind_group_layout,
353                &self.uniform_buffers[i],
354                source,
355                &self.sampler,
356                i,
357            );
358        }
359    }
360
361    /// Encode the 6 blur passes for this frame. Caller is responsible
362    /// for `submit`-ing the encoder; the pipeline assumes the input
363    /// `render_texture` already holds the current warp pass output.
364    #[allow(clippy::too_many_arguments)]
365    pub fn render(
366        &self,
367        encoder: &mut wgpu::CommandEncoder,
368        blur1_view: &wgpu::TextureView,
369        blur2_view: &wgpu::TextureView,
370        blur3_view: &wgpu::TextureView,
371        blur1_scratch_view: &wgpu::TextureView,
372        blur2_scratch_view: &wgpu::TextureView,
373        blur3_scratch_view: &wgpu::TextureView,
374    ) {
375        let targets: [&wgpu::TextureView; 6] = [
376            blur1_scratch_view, // pass 0 H → b1_scratch (½)
377            blur1_view,         // pass 1 V → blur1 (½)
378            blur2_scratch_view, // pass 2 H → b2_scratch (¼)
379            blur2_view,         // pass 3 V → blur2 (¼)
380            blur3_scratch_view, // pass 4 H → b3_scratch (⅛)
381            blur3_view,         // pass 5 V → blur3 (⅛)
382        ];
383        for (i, target) in targets.iter().enumerate() {
384            let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
385                label: Some(&format!("Blur Pass #{i}")),
386                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
387                    view: target,
388                    depth_slice: None,
389                    resolve_target: None,
390                    ops: wgpu::Operations {
391                        load: wgpu::LoadOp::Clear(wgpu::Color::TRANSPARENT),
392                        store: wgpu::StoreOp::Store,
393                    },
394                })],
395                depth_stencil_attachment: None,
396                timestamp_writes: None,
397                occlusion_query_set: None,
398                multiview_mask: None,
399            });
400            pass.set_pipeline(&self.pipeline);
401            pass.set_bind_group(0, &self.bind_groups[i], &[]);
402            pass.draw(0..3, 0..1);
403        }
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chain_textures::ChainTextures;
    use crate::config::{RenderConfig, TextureFormat};
    use crate::gpu_context::GpuContext;

    /// Build a `BlurPipeline` whose passes read from and write to the
    /// textures owned by `chain_tex`.
    fn blur_for(
        device: &wgpu::Device,
        format: wgpu::TextureFormat,
        width: u32,
        height: u32,
        chain_tex: &ChainTextures,
    ) -> BlurPipeline {
        BlurPipeline::new(
            device,
            format,
            width,
            height,
            &chain_tex.render_texture_view,
            &chain_tex.blur1_texture_view,
            &chain_tex.blur2_texture_view,
            &chain_tex.blur3_texture_view,
            &chain_tex.blur1_scratch_texture_view,
            &chain_tex.blur2_scratch_texture_view,
            &chain_tex.blur3_scratch_texture_view,
        )
    }

    /// Smoke test: building the pipeline for the default config must not
    /// panic.
    #[test]
    fn blur_pipeline_instantiates() {
        let gpu = pollster::block_on(GpuContext::new(RenderConfig::default())).unwrap();
        let chain_tex = ChainTextures::new(&gpu.device, &gpu.config);
        let _pipeline = blur_for(
            &gpu.device,
            gpu.config.texture_format.to_wgpu(),
            gpu.config.width,
            gpu.config.height,
            &chain_tex,
        );
    }

    /// Seed `render_texture` with a hard vertical edge (left half black,
    /// right half white), run the six blur passes, and check that `blur1`
    /// holds a grey value just left of the edge. A straight copy would
    /// leave that texel at 0 or 255; only the Gaussian spreads energy there.
    #[test]
    fn blur_pipeline_spreads_energy_across_boundary() {
        let gpu = pollster::block_on(GpuContext::new(RenderConfig {
            width: 64,
            height: 64,
            texture_format: TextureFormat::Bgra8Unorm,
            ..Default::default()
        }))
        .unwrap();
        let chain_tex = ChainTextures::new(&gpu.device, &gpu.config);
        let blur = blur_for(
            &gpu.device,
            gpu.config.texture_format.to_wgpu(),
            gpu.config.width,
            gpu.config.height,
            &chain_tex,
        );

        // 64×64 BGRA, opaque: columns 0..32 black, 32..64 white.
        let mut seed: Vec<u8> = Vec::with_capacity(64 * 64 * 4);
        for _row in 0..64 {
            for col in 0..64 {
                let level: u8 = if col < 32 { 0 } else { 255 };
                seed.extend_from_slice(&[level, level, level, 255u8]);
            }
        }
        let full_extent = wgpu::Extent3d {
            width: 64,
            height: 64,
            depth_or_array_layers: 1,
        };
        gpu.queue.write_texture(
            wgpu::TexelCopyTextureInfo {
                texture: &chain_tex.render_texture,
                mip_level: 0,
                origin: wgpu::Origin3d::ZERO,
                aspect: wgpu::TextureAspect::All,
            },
            &seed,
            wgpu::TexelCopyBufferLayout {
                offset: 0,
                bytes_per_row: Some(64 * 4),
                rows_per_image: Some(64),
            },
            full_extent,
        );

        let mut blur_encoder = gpu.device.create_command_encoder(&Default::default());
        blur.render(
            &mut blur_encoder,
            &chain_tex.blur1_texture_view,
            &chain_tex.blur2_texture_view,
            &chain_tex.blur3_texture_view,
            &chain_tex.blur1_scratch_texture_view,
            &chain_tex.blur2_scratch_texture_view,
            &chain_tex.blur3_scratch_texture_view,
        );
        gpu.queue.submit(std::iter::once(blur_encoder.finish()));

        // Read blur1 back at its native ½ resolution (32 × 32). Buffer rows
        // must be padded to COPY_BYTES_PER_ROW_ALIGNMENT.
        let (out_w, out_h): (u32, u32) = (32, 32);
        let row_align = wgpu::COPY_BYTES_PER_ROW_ALIGNMENT;
        let row_stride = (out_w * 4).div_ceil(row_align) * row_align;
        let readback = gpu.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Blur Test Staging"),
            size: (row_stride * out_h) as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut copy_encoder = gpu.device.create_command_encoder(&Default::default());
        copy_encoder.copy_texture_to_buffer(
            wgpu::TexelCopyTextureInfo {
                texture: &chain_tex.blur1_texture,
                mip_level: 0,
                origin: wgpu::Origin3d::ZERO,
                aspect: wgpu::TextureAspect::All,
            },
            wgpu::TexelCopyBufferInfo {
                buffer: &readback,
                layout: wgpu::TexelCopyBufferLayout {
                    offset: 0,
                    bytes_per_row: Some(row_stride),
                    rows_per_image: Some(out_h),
                },
            },
            wgpu::Extent3d {
                width: out_w,
                height: out_h,
                depth_or_array_layers: 1,
            },
        );
        gpu.queue.submit(std::iter::once(copy_encoder.finish()));

        let (done_tx, done_rx) = std::sync::mpsc::channel();
        readback.slice(..).map_async(wgpu::MapMode::Read, move |status| {
            let _ = done_tx.send(status);
        });
        gpu.device.poll(wgpu::PollType::wait_indefinitely()).ok();
        done_rx.recv().unwrap().unwrap();
        let mapped = readback.slice(..).get_mapped_range();

        // Centre row, ½-res column 15: the texel just left of the edge
        // (full-res columns 31/32 map to ½-res columns 15/16).
        let probe_row = out_h / 2;
        let probe_col = 15;
        let b = mapped[(probe_row * row_stride + probe_col * 4) as usize];
        // Bilinear downsampling alone would leave this texel black; only the
        // Gaussian pulls in white from across the edge.
        assert!(
            (21..200).contains(&b),
            "boundary pixel should be partially blurred; got {b}"
        );
        drop(mapped);
        readback.unmap();
    }
}