gpu/
vector.rs

1#[expect(private_bounds)]
2pub trait VecTypeTrait: Default + Copy + core::ops::Add<Output = Self> + VecTypeInternal {
3    type InnerType: Copy
4        + Default
5        + core::ops::IndexMut<usize, Output = Self::Elem>
6        + core::iter::IntoIterator<Item = Self::Elem>;
7    type Elem: Copy + Default + core::ops::Add<Output = Self::Elem>;
8
9    fn new(data: Self::InnerType) -> Self;
10
11    fn data(&self) -> &Self::InnerType;
12
13    fn iter_mut(&mut self) -> core::slice::IterMut<'_, Self::Elem>;
14
15    fn iter(&self) -> core::slice::Iter<'_, Self::Elem>;
16}
17
/// Internal trait to restrict the `VecType` generic parameter.
///
/// It is a private supertrait of `VecTypeTrait`, so a type can only implement
/// `VecTypeTrait` if it also implements this marker, which cannot be named
/// from outside its visibility scope.
trait VecTypeInternal {}
20
/// Generates an aligned, fixed-size vector storage type plus its `VecType` alias.
///
/// Arguments:
/// - `$name`: public alias defined as `VecType<$inner>`.
/// - `$inner`: generated storage struct wrapping `[$base; $N]`.
/// - `$align`: alignment in bytes applied via `repr(align)`. It must not add
///   padding to the struct; this is verified at compile time.
/// - `$base`: scalar element type.
/// - `$N`: number of elements.
///
/// Besides the struct, the macro implements `From<[$base; $N]>`,
/// `VecTypeTrait`, `VecTypeInternal`, element-wise `Add`, `Index<usize>` and
/// `IndexMut<usize>` for `$inner`.
macro_rules! impl_floatn_from {
    ($name: ident, $inner: ident, $align: literal, $base: ty, $N: literal) => {
        #[doc = concat!(
            "Storage for `", stringify!($name), "`: ", stringify!($N), " `",
            stringify!($base), "` elements aligned to ", stringify!($align), " bytes."
        )]
        #[derive(Clone, Copy, PartialEq, Debug, Default)]
        #[repr(C, align($align))]
        pub struct $inner {
            /// Elements of the vector in index order.
            pub data: [$base; $N],
        }

        // Code that reinterprets a slice of these vectors as a flat slice of
        // elements needs the struct to be exactly as large as its array, i.e.
        // the requested alignment must not introduce trailing padding.
        const _: () = assert!(
            core::mem::size_of::<$inner>() == core::mem::size_of::<[$base; $N]>(),
            "alignment must not add padding to the vector storage type"
        );

        impl From<[$base; $N]> for $inner {
            /// Wraps an element array in the aligned storage type.
            #[inline(always)]
            #[gpu_codegen::device]
            fn from(v: [$base; $N]) -> Self {
                Self { data: v }
            }
        }

        #[doc = concat!("Vector of ", stringify!($N), " `", stringify!($base), "` elements.")]
        pub type $name = VecType<$inner>;

        impl VecTypeTrait for $inner {
            type Elem = $base;
            type InnerType = [$base; $N];

            /// Wraps an element array in the aligned storage type.
            #[inline(always)]
            #[gpu_codegen::device]
            fn new(data: Self::InnerType) -> Self {
                Self { data }
            }

            /// Borrows the underlying element array.
            #[inline(always)]
            #[gpu_codegen::device]
            fn data(&self) -> &Self::InnerType {
                &self.data
            }

            /// Iterates over the elements by mutable reference.
            #[inline(always)]
            #[gpu_codegen::device]
            fn iter_mut(&mut self) -> core::slice::IterMut<'_, Self::Elem> {
                self.data.iter_mut()
            }

            /// Iterates over the elements by shared reference.
            #[inline(always)]
            #[gpu_codegen::device]
            fn iter(&self) -> core::slice::Iter<'_, Self::Elem> {
                self.data.iter()
            }
        }

        impl VecTypeInternal for $inner {}

        impl core::ops::Add for $inner {
            type Output = Self;
            /// Adds two vectors element-wise.
            #[inline(always)]
            #[gpu_codegen::device]
            fn add(self, rhs: Self) -> Self {
                let mut data: [$base; $N] = [Default::default(); $N];
                for i in 0..$N {
                    data[i] = self.data[i] + rhs.data[i];
                }
                $inner { data }
            }
        }

        impl core::ops::Index<usize> for $inner {
            type Output = $base;
            /// Returns the element at `index`, using the array's own indexing.
            #[inline(always)]
            #[gpu_codegen::device]
            fn index(&self, index: usize) -> &Self::Output {
                &self.data[index]
            }
        }

        impl core::ops::IndexMut<usize> for $inner {
            /// Returns the element at `index` mutably, using the array's own indexing.
            #[inline(always)]
            #[gpu_codegen::device]
            fn index_mut(&mut self, index: usize) -> &mut Self::Output {
                &mut self.data[index]
            }
        }
    };
}
98
99/// A vector type with `N` elements of type `T::Elem`.
100/// T must be aligned by the size of the vector.
101/// Since we cannot do `repr(align(N * size_of::<T::Elem>))` yet,
102/// we define separate types for different sizes.
103#[derive(Clone, Copy, PartialEq, Debug)]
104pub struct VecType<T: VecTypeTrait> {
105    val: T,
106}
107
108impl<T: VecTypeTrait> Default for VecType<T> {
109    #[inline(always)]
110    #[gpu_codegen::device]
111    fn default() -> Self {
112        Self { val: T::new(Default::default()) }
113    }
114}
115
116impl<T: VecTypeTrait> VecType<T> {
117    #[inline(always)]
118    #[gpu_codegen::device]
119    pub fn new(val: T::InnerType) -> Self {
120        VecType { val: T::new(val) }
121    }
122}
123
124impl<T: VecTypeTrait> core::ops::Add for VecType<T> {
125    type Output = Self;
126    /// Adds two vectors element-wise.
127    #[inline(always)]
128    #[gpu_codegen::device]
129    fn add(self, rhs: Self) -> Self {
130        let val = self.val + rhs.val;
131        VecType { val }
132    }
133}
134
135impl<T: VecTypeTrait> core::ops::Deref for VecType<T> {
136    type Target = T;
137    #[inline(always)]
138    #[gpu_codegen::device]
139    fn deref(&self) -> &T {
140        &self.val
141    }
142}
143
144impl<T: VecTypeTrait> core::ops::DerefMut for VecType<T> {
145    #[inline(always)]
146    #[gpu_codegen::device]
147    fn deref_mut(&mut self) -> &mut T {
148        &mut self.val
149    }
150}
151
// `f32` vectors with 2, 4 and 8 elements. Each alignment argument equals the
// byte size of the whole vector (element count * 4 bytes).
impl_floatn_from!(Float2, Float2Inner, 8, f32, 2);
impl_floatn_from!(Float4, Float4Inner, 16, f32, 4);
impl_floatn_from!(Float8, Float8Inner, 32, f32, 8);

// `u32` vectors with the same element counts and alignments as the `f32` ones.
impl_floatn_from!(U32_2, U32_2Inner, 8, u32, 2);
impl_floatn_from!(U32_4, U32_4Inner, 16, u32, 4);
impl_floatn_from!(U32_8, U32_8Inner, 32, u32, 8);
159
/// View of a value as a flat slice of `T2` elements.
pub trait VecFlatten<T2> {
    /// Returns a slice of `T2` values borrowed from `self`.
    fn flatten(&self) -> &[T2];
}
163
164/// Useful to optimize code with vector load/store.
165/// If length of the slice is not a multiple of N,
166/// the remaining elements will be ignored.
167///
168/// # Safety
169/// This is safe since VecType<T> always has a layout \
170/// compatible with T::Elem array.
171impl<T, T2> VecFlatten<T2> for [VecType<T>]
172where
173    T: VecTypeTrait<Elem = T2>,
174{
175    fn flatten(&self) -> &[T2] {
176        // SAFETY: the returned slice will be at same size or shorter, so it is safe.
177        assert!(size_of::<T>() >= size_of::<T2>(), "T2 is larger than T");
178        assert!(align_of::<T>() >= align_of::<T2>(), "T2 has stricter alignment than T");
179        unsafe {
180            &*core::ptr::slice_from_raw_parts_mut(
181                self.as_ptr() as _,
182                self.len() * size_of::<T>() / size_of::<T2>(),
183            )
184        }
185    }
186}