1use bytemuck::{Pod, Zeroable};
27use wgpu::util::DeviceExt;
28
29use crate::config::MotionVectorParams;
30use crate::warp_pipeline::WarpVertex;
31
/// Largest number of motion-vector columns; requested grids are clamped to this.
pub const MAX_MOTION_VECTOR_GRID_X: u32 = 64;
/// Largest number of motion-vector rows; requested grids are clamped to this.
pub const MAX_MOTION_VECTOR_GRID_Y: u32 = 48;

/// Most segments ever emitted (one per grid cell at the maximum grid size).
/// The renderer's vertex buffer holds twice this many vertices.
pub const MAX_MOTION_VECTOR_SEGMENTS: usize =
    (MAX_MOTION_VECTOR_GRID_X * MAX_MOTION_VECTOR_GRID_Y) as usize;
41
/// One endpoint of a motion-vector line segment, uploaded verbatim to the GPU.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets a `&[MotionVectorVertex]`
/// be reinterpreted as bytes with `bytemuck::cast_slice`. The field order must
/// stay in sync with the vertex attribute offsets declared in
/// `MotionVectorRenderer::new` (`pos` at offset 0, `color` right after it).
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Pod, Zeroable, Default)]
pub struct MotionVectorVertex {
    /// Position in clip space (shader location 0, `Float32x2`).
    pub pos: [f32; 2],
    /// RGBA colour (shader location 1, `Float32x4`).
    pub color: [f32; 4],
}
51
/// Draws a grid of motion-vector line segments on top of an existing render target.
pub struct MotionVectorRenderer {
    /// Line-list pipeline with alpha blending.
    pipeline: wgpu::RenderPipeline,
    /// Fixed-size vertex buffer with room for `MAX_MOTION_VECTOR_SEGMENTS * 2` vertices.
    vertex_buffer: wgpu::Buffer,
    /// Vertices written by the most recent `update`; 0 means there is nothing to draw.
    vertex_count: u32,
}
60
61impl MotionVectorRenderer {
62 pub fn new(device: &wgpu::Device, format: wgpu::TextureFormat) -> Self {
63 let shader = crate::pipeline_helpers::load_wgsl(
64 device,
65 "Motion Vector Shader",
66 include_str!("../shaders/motion_vector.wgsl"),
67 );
68
69 let initial: Vec<MotionVectorVertex> =
72 vec![MotionVectorVertex::default(); MAX_MOTION_VECTOR_SEGMENTS * 2];
73 let vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
74 label: Some("Motion Vector Vertices"),
75 contents: bytemuck::cast_slice(&initial),
76 usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
77 });
78
79 let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
80 label: Some("Motion Vector Layout"),
81 bind_group_layouts: &[],
82 immediate_size: 0,
83 });
84
85 let vertex_attributes = [
86 wgpu::VertexAttribute {
87 offset: 0,
88 shader_location: 0,
89 format: wgpu::VertexFormat::Float32x2,
90 },
91 wgpu::VertexAttribute {
92 offset: std::mem::size_of::<[f32; 2]>() as u64,
93 shader_location: 1,
94 format: wgpu::VertexFormat::Float32x4,
95 },
96 ];
97 let vertex_layout = wgpu::VertexBufferLayout {
98 array_stride: std::mem::size_of::<MotionVectorVertex>() as u64,
99 step_mode: wgpu::VertexStepMode::Vertex,
100 attributes: &vertex_attributes,
101 };
102
103 let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
104 label: Some("Motion Vector Pipeline"),
105 layout: Some(&layout),
106 vertex: wgpu::VertexState {
107 module: &shader,
108 entry_point: Some("vs_main"),
109 buffers: std::slice::from_ref(&vertex_layout),
110 compilation_options: Default::default(),
111 },
112 fragment: Some(wgpu::FragmentState {
113 module: &shader,
114 entry_point: Some("fs_main"),
115 targets: &[Some(wgpu::ColorTargetState {
116 format,
117 blend: Some(wgpu::BlendState::ALPHA_BLENDING),
118 write_mask: wgpu::ColorWrites::ALL,
119 })],
120 compilation_options: Default::default(),
121 }),
122 primitive: wgpu::PrimitiveState {
123 topology: wgpu::PrimitiveTopology::LineList,
124 ..Default::default()
125 },
126 depth_stencil: None,
127 multisample: wgpu::MultisampleState::default(),
128 multiview_mask: None,
129 cache: None,
130 });
131
132 Self {
133 pipeline,
134 vertex_buffer,
135 vertex_count: 0,
136 }
137 }
138
139 pub fn update(
150 &mut self,
151 queue: &wgpu::Queue,
152 params: MotionVectorParams,
153 warp_field: Option<WarpField<'_>>,
154 ) -> u32 {
155 let nx = params.grid_x.min(MAX_MOTION_VECTOR_GRID_X);
156 let ny = params.grid_y.min(MAX_MOTION_VECTOR_GRID_Y);
157 let alpha = params.color[3];
158 if nx == 0 || ny == 0 || alpha <= 0.0 {
159 self.vertex_count = 0;
160 return 0;
161 }
162
163 let mut vertices: Vec<MotionVectorVertex> =
164 Vec::with_capacity((nx as usize) * (ny as usize) * 2);
165 build_segments(nx, ny, params, warp_field, &mut vertices);
166
167 if !vertices.is_empty() {
168 queue.write_buffer(&self.vertex_buffer, 0, bytemuck::cast_slice(&vertices));
169 }
170 self.vertex_count = vertices.len() as u32;
171 (vertices.len() / 2) as u32
172 }
173
174 pub fn render(&self, encoder: &mut wgpu::CommandEncoder, view: &wgpu::TextureView) {
175 if self.vertex_count == 0 {
176 return;
177 }
178 let mut rp = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
179 label: Some("Motion Vector Pass"),
180 color_attachments: &[Some(wgpu::RenderPassColorAttachment {
181 view,
182 depth_slice: None,
183 resolve_target: None,
184 ops: wgpu::Operations {
185 load: wgpu::LoadOp::Load,
186 store: wgpu::StoreOp::Store,
187 },
188 })],
189 depth_stencil_attachment: None,
190 timestamp_writes: None,
191 occlusion_query_set: None,
192 multiview_mask: None,
193 });
194 rp.set_pipeline(&self.pipeline);
195 rp.set_vertex_buffer(0, self.vertex_buffer.slice(..));
196 rp.draw(0..self.vertex_count, 0..1);
197 }
198
199 pub fn segment_count(&self) -> u32 {
202 self.vertex_count / 2
203 }
204}
205
/// Borrowed view of a warp mesh, used to orient motion vectors along the warp.
///
/// `sample_warp_displacement` indexes `vertices` row-major as
/// `row * cols + col`. NOTE(review): callers presumably supply exactly
/// `cols * rows` vertices — confirm with the code that builds this view.
pub struct WarpField<'a> {
    /// Mesh vertices per row.
    pub cols: u32,
    /// Number of mesh rows.
    pub rows: u32,
    /// Mesh vertices in row-major order.
    pub vertices: &'a [WarpVertex],
}
217
218fn build_segments(
223 nx: u32,
224 ny: u32,
225 params: MotionVectorParams,
226 warp_field: Option<WarpField<'_>>,
227 vertices: &mut Vec<MotionVectorVertex>,
228) {
229 let inv_nx = 1.0 / nx as f32;
230 let inv_ny = 1.0 / ny as f32;
231 let seg_len_preset = params.length * inv_nx;
235
236 for j in 0..ny {
237 for i in 0..nx {
238 let cx_preset = (i as f32 + 0.5) * inv_nx + params.dx;
239 let cy_preset = (j as f32 + 0.5) * inv_ny + params.dy;
240 let (dir_x, dir_y) = if let Some(field) = warp_field.as_ref() {
241 sample_warp_displacement(field, cx_preset, cy_preset)
242 } else {
243 (seg_len_preset, 0.0)
245 };
246 let (sx, sy) = if warp_field.is_some() {
252 (dir_x * params.length * 8.0, dir_y * params.length * 8.0)
253 } else {
254 (dir_x, dir_y)
255 };
256 let head = preset_xy_to_clip(cx_preset, cy_preset);
257 let tail = preset_xy_to_clip(cx_preset + sx, cy_preset + sy);
258 vertices.push(MotionVectorVertex {
259 pos: head,
260 color: params.color,
261 });
262 vertices.push(MotionVectorVertex {
263 pos: tail,
264 color: params.color,
265 });
266 }
267 }
268}
269
270fn sample_warp_displacement(field: &WarpField<'_>, x: f32, y: f32) -> (f32, f32) {
276 if field.cols < 2 || field.rows < 2 || field.vertices.is_empty() {
277 return (0.0, 0.0);
278 }
279 let cols_f = (field.cols - 1) as f32;
280 let rows_f = (field.rows - 1) as f32;
281 let mc = (x.clamp(0.0, 1.0)) * cols_f;
282 let mr = (y.clamp(0.0, 1.0)) * rows_f;
283 let c0 = mc.floor().clamp(0.0, cols_f - 1.0) as u32;
284 let r0 = mr.floor().clamp(0.0, rows_f - 1.0) as u32;
285 let c1 = (c0 + 1).min(field.cols - 1);
286 let r1 = (r0 + 1).min(field.rows - 1);
287 let tu = mc - c0 as f32;
288 let tv = mr - r0 as f32;
289
290 let idx = |c: u32, r: u32| -> usize { (r * field.cols + c) as usize };
291 let disp = |v: &WarpVertex| -> (f32, f32) {
292 let ox = (v.pos_clip[0] + 1.0) * 0.5;
295 let oy = (1.0 - v.pos_clip[1]) * 0.5;
296 (v.uv_warp[0] - ox, v.uv_warp[1] - oy)
297 };
298 let d00 = disp(&field.vertices[idx(c0, r0)]);
299 let d10 = disp(&field.vertices[idx(c1, r0)]);
300 let d01 = disp(&field.vertices[idx(c0, r1)]);
301 let d11 = disp(&field.vertices[idx(c1, r1)]);
302
303 let mix = |a: f32, b: f32, t: f32| a + (b - a) * t;
304 let dx = mix(mix(d00.0, d10.0, tu), mix(d01.0, d11.0, tu), tv);
305 let dy = mix(mix(d00.1, d10.1, tu), mix(d01.1, d11.1, tu), tv);
306 (dx, dy)
307}
308
/// Maps a preset-space point (`[0, 1]` per axis, y pointing down) to clip
/// space (`[-1, 1]` per axis, y pointing up).
#[inline]
fn preset_xy_to_clip(x: f32, y: f32) -> [f32; 2] {
    let clip_x = 2.0 * x - 1.0;
    let clip_y = 1.0 - 2.0 * y;
    [clip_x, clip_y]
}
317
#[cfg(test)]
mod tests {
    use super::*;

    // Opaque white 4x3 grid with no anchor offset. With no warp field,
    // `length: 0.5` yields segments half a grid cell wide.
    fn enabled_params() -> MotionVectorParams {
        MotionVectorParams {
            grid_x: 4,
            grid_y: 3,
            dx: 0.0,
            dy: 0.0,
            length: 0.5,
            color: [1.0, 1.0, 1.0, 1.0],
        }
    }

    // Every grid cell contributes exactly one segment, i.e. two vertices.
    #[test]
    fn segment_count_matches_grid_size() {
        let mut v = Vec::new();
        build_segments(4, 3, enabled_params(), None, &mut v);
        assert_eq!(v.len(), 4 * 3 * 2);
    }

    // Without a warp field every segment is horizontal, running from its
    // anchor (head) towards +x (tail).
    #[test]
    fn segment_endpoints_are_horizontal() {
        let mut v = Vec::new();
        build_segments(4, 3, enabled_params(), None, &mut v);
        for pair in v.chunks(2) {
            assert!((pair[0].pos[1] - pair[1].pos[1]).abs() < 1e-6);
            // Tail lies to the right of the head.
            assert!(pair[0].pos[0] < pair[1].pos[0]);
        }
    }

    // `dx`/`dy` translate every vertex. Preset space spans 1 unit per axis and
    // clip space spans 2, so offsets double. The y axis is flipped, so a
    // negative preset `dy` moves vertices towards positive clip y.
    #[test]
    fn dx_dy_shift_anchor() {
        let mut a = Vec::new();
        let mut b = Vec::new();
        let p0 = enabled_params();
        let mut p1 = p0;
        p1.dx = 0.25;
        p1.dy = -0.1;
        build_segments(2, 2, p0, None, &mut a);
        build_segments(2, 2, p1, None, &mut b);
        assert_eq!(a.len(), b.len());
        let dx_clip = 0.25 * 2.0;
        let dy_clip = -(-0.1) * 2.0;
        for (av, bv) in a.iter().zip(b.iter()) {
            assert!((bv.pos[0] - av.pos[0] - dx_clip).abs() < 1e-5);
            assert!((bv.pos[1] - av.pos[1] - dy_clip).abs() < 1e-5);
        }
    }

    // Both endpoints of every segment carry the configured RGBA colour unchanged.
    #[test]
    fn segment_carries_color() {
        let mut v = Vec::new();
        let mut p = enabled_params();
        p.color = [0.2, 0.4, 0.6, 0.5];
        build_segments(2, 2, p, None, &mut v);
        for vert in &v {
            assert_eq!(vert.color, p.color);
        }
    }

    // The preset-space centre (0.5, 0.5) maps to the clip-space origin.
    #[test]
    fn preset_to_clip_centre_is_origin() {
        let p = preset_xy_to_clip(0.5, 0.5);
        assert!(p[0].abs() < 1e-6);
        assert!(p[1].abs() < 1e-6);
    }

    // 2x2 mesh whose `uv_warp` is each vertex's own unwarped preset position
    // (`uv_orig`) plus a constant `off`. Bilinear sampling must return `off`
    // everywhere, including the corners.
    #[test]
    fn warp_field_sampler_returns_constant_offset() {
        let off = (0.1f32, -0.05f32);
        let mk = |pc: [f32; 2], uv_orig: [f32; 2]| WarpVertex {
            pos_clip: pc,
            uv_warp: [uv_orig[0] + off.0, uv_orig[1] + off.1],
        };
        let verts = vec![
            mk([-1.0, -1.0], [0.0, 1.0]),
            mk([1.0, -1.0], [1.0, 1.0]),
            mk([-1.0, 1.0], [0.0, 0.0]),
            mk([1.0, 1.0], [1.0, 0.0]),
        ];
        let field = WarpField {
            cols: 2,
            rows: 2,
            vertices: &verts,
        };
        for (x, y) in [(0.0, 0.0), (0.5, 0.5), (1.0, 1.0), (0.25, 0.75)] {
            let (dx, dy) = sample_warp_displacement(&field, x, y);
            assert!(
                (dx - off.0).abs() < 1e-5 && (dy - off.1).abs() < 1e-5,
                "displacement at ({x}, {y}) was ({dx}, {dy})"
            );
        }
    }

    // Warp moves UVs by +`off_y` in preset space (downwards), with no x
    // component. Segment tails should therefore sit directly below their
    // heads, i.e. at lower clip y.
    #[test]
    fn segments_pick_up_warp_field_direction() {
        let off_y = 0.2f32;
        let mk = |pc: [f32; 2], uv_orig: [f32; 2]| WarpVertex {
            pos_clip: pc,
            uv_warp: [uv_orig[0], uv_orig[1] + off_y],
        };
        let verts = vec![
            mk([-1.0, -1.0], [0.0, 1.0]),
            mk([1.0, -1.0], [1.0, 1.0]),
            mk([-1.0, 1.0], [0.0, 0.0]),
            mk([1.0, 1.0], [1.0, 0.0]),
        ];
        let field = WarpField {
            cols: 2,
            rows: 2,
            vertices: &verts,
        };
        let mut v = Vec::new();
        build_segments(2, 2, enabled_params(), Some(field), &mut v);
        for pair in v.chunks(2) {
            let dx = pair[1].pos[0] - pair[0].pos[0];
            let dy = pair[1].pos[1] - pair[0].pos[1];
            assert!(dx.abs() < 1e-5, "x drift: {dx}");
            // Positive preset-space y displacement becomes negative clip-space y.
            assert!(dy < -1e-3, "expected downward tail, got {dy}");
        }
    }
}