/// Common interface for the inner storage of a [`VecType`] vector.
///
/// Sealed via the private `VecTypeInternal` supertrait, so only types
/// generated inside this module (by `impl_floatn_from!`) can implement it.
/// `#[expect(private_bounds)]` acknowledges that this public trait bounds on
/// a private supertrait — the standard sealed-trait pattern.
#[expect(private_bounds)]
pub trait VecTypeTrait: Default + Copy + core::ops::Add<Output = Self> + VecTypeInternal {
    /// Raw backing storage (e.g. `[f32; 4]`); indexable and iterable by element.
    type InnerType: Copy
        + Default
        + core::ops::IndexMut<usize, Output = Self::Elem>
        + core::iter::IntoIterator<Item = Self::Elem>;
    /// Scalar lane type (e.g. `f32` or `u32`).
    type Elem: Copy + Default + core::ops::Add<Output = Self::Elem>;

    /// Wraps raw storage in the implementing type.
    fn new(data: Self::InnerType) -> Self;

    /// Borrows the raw storage.
    fn data(&self) -> &Self::InnerType;

    /// Mutable iterator over the lanes.
    fn iter_mut(&mut self) -> core::slice::IterMut<'_, Self::Elem>;

    /// Shared iterator over the lanes.
    fn iter(&self) -> core::slice::Iter<'_, Self::Elem>;
}
17
/// Private marker supertrait that seals [`VecTypeTrait`]: external crates
/// cannot name this trait, so they cannot implement `VecTypeTrait`.
trait VecTypeInternal {}
20
/// Generates one fixed-width vector type:
/// - `$inner`: the `#[repr(C, align($align))]` storage struct over `[$base; $N]`,
/// - `$name`: the public alias `VecType<$inner>`,
/// plus `From<[$base; $N]>`, `VecTypeTrait`, `Add`, `Index`, and `IndexMut` impls.
///
/// NOTE(review): despite the `floatn` name, this macro is also instantiated
/// for integer bases (`u32`) below — consider a more general name if renaming
/// is ever on the table.
macro_rules! impl_floatn_from {
    ($name: ident, $inner: ident, $align: literal, $base: ty, $N: literal) => {
        // Storage struct. The explicit alignment is presumably chosen to match
        // the target's vector-type alignment requirements — confirm against
        // the GPU backend ABI.
        #[derive(Clone, Copy, PartialEq, Debug, Default)]
        #[repr(C, align($align))]
        pub struct $inner {
            pub data: [$base; $N],
        }
        impl From<[$base; $N]> for $inner {
            #[inline(always)]
            #[gpu_codegen::device]
            fn from(v: [$base; $N]) -> Self {
                Self { data: v }
            }
        }
        /// Public, ergonomic alias for this lane width.
        pub type $name = VecType<$inner>;
        impl VecTypeTrait for $inner {
            type Elem = $base;
            type InnerType = [$base; $N];
            #[inline(always)]
            #[gpu_codegen::device]
            fn new(data: Self::InnerType) -> Self {
                Self { data }
            }

            #[inline(always)]
            #[gpu_codegen::device]
            fn data(&self) -> &Self::InnerType {
                &self.data
            }

            #[inline(always)]
            #[gpu_codegen::device]
            fn iter_mut(&mut self) -> core::slice::IterMut<'_, Self::Elem> {
                self.data.iter_mut()
            }

            #[inline(always)]
            #[gpu_codegen::device]
            fn iter(&self) -> core::slice::Iter<'_, Self::Elem> {
                self.data.iter()
            }
        }

        // Marker impl that admits this type into the sealed `VecTypeTrait`.
        impl VecTypeInternal for $inner {}

        impl core::ops::Add for $inner {
            type Output = Self;
            #[inline(always)]
            #[gpu_codegen::device]
            fn add(self, rhs: Self) -> Self {
                // Lane-wise addition into a fresh, default-initialized array.
                let mut data: [$base; $N] = [Default::default(); $N];
                for i in 0..$N {
                    data[i] = self.data[i] + rhs.data[i];
                }
                $inner { data }
            }
        }

        impl core::ops::Index<usize> for $inner {
            type Output = $base;
            #[inline(always)]
            #[gpu_codegen::device]
            fn index(&self, index: usize) -> &Self::Output {
                &self.data[index]
            }
        }

        impl core::ops::IndexMut<usize> for $inner {
            #[inline(always)]
            #[gpu_codegen::device]
            fn index_mut(&mut self, index: usize) -> &mut Self::Output {
                &mut self.data[index]
            }
        }
    };
}
98
/// Thin `Copy` wrapper around a fixed-width vector payload `T` (one of the
/// `*Inner` types generated by `impl_floatn_from!`). The inner value's API
/// is exposed via the `Deref`/`DerefMut` impls further down in this file.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct VecType<T: VecTypeTrait> {
    val: T,
}
107
108impl<T: VecTypeTrait> Default for VecType<T> {
109 #[inline(always)]
110 #[gpu_codegen::device]
111 fn default() -> Self {
112 Self { val: T::new(Default::default()) }
113 }
114}
115
116impl<T: VecTypeTrait> VecType<T> {
117 #[inline(always)]
118 #[gpu_codegen::device]
119 pub fn new(val: T::InnerType) -> Self {
120 VecType { val: T::new(val) }
121 }
122}
123
124impl<T: VecTypeTrait> core::ops::Add for VecType<T> {
125 type Output = Self;
126 #[inline(always)]
128 #[gpu_codegen::device]
129 fn add(self, rhs: Self) -> Self {
130 let val = self.val + rhs.val;
131 VecType { val }
132 }
133}
134
impl<T: VecTypeTrait> core::ops::Deref for VecType<T> {
    type Target = T;
    /// Exposes the inner payload's API (`iter`, `data`, indexing) via
    /// auto-deref. NOTE(review): `Deref` to a non-pointer type is generally
    /// discouraged; here it appears to be a deliberate ergonomic choice to
    /// make `VecType<T>` a transparent wrapper — keep in mind if the API
    /// ever grows methods that shadow the inner type's.
    #[inline(always)]
    #[gpu_codegen::device]
    fn deref(&self) -> &T {
        &self.val
    }
}
143
impl<T: VecTypeTrait> core::ops::DerefMut for VecType<T> {
    /// Mutable counterpart of the `Deref` impl: grants direct mutable access
    /// to the inner payload (e.g. `v.iter_mut()`, `v[i] = x`).
    #[inline(always)]
    #[gpu_codegen::device]
    fn deref_mut(&mut self) -> &mut T {
        &mut self.val
    }
}
151
// Arguments: public alias, inner struct name, alignment in bytes, lane type,
// lane count. Alignment scales with total size (2/4/8 lanes -> 8/16/32 bytes).
impl_floatn_from!(Float2, Float2Inner, 8, f32, 2);
impl_floatn_from!(Float4, Float4Inner, 16, f32, 4);
impl_floatn_from!(Float8, Float8Inner, 32, f32, 8);

impl_floatn_from!(U32_2, U32_2Inner, 8, u32, 2);
impl_floatn_from!(U32_4, U32_4Inner, 16, u32, 4);
impl_floatn_from!(U32_8, U32_8Inner, 32, u32, 8);
159
/// Zero-copy reinterpretation of a slice of vectors as a flat slice of
/// scalar elements `T2`.
pub trait VecFlatten<T2> {
    /// Returns the same memory viewed as `&[T2]` (no allocation, no copy).
    fn flatten(&self) -> &[T2];
}
163
164impl<T, T2> VecFlatten<T2> for [VecType<T>]
172where
173 T: VecTypeTrait<Elem = T2>,
174{
175 fn flatten(&self) -> &[T2] {
176 assert!(size_of::<T>() >= size_of::<T2>(), "T2 is larger than T");
178 assert!(align_of::<T>() >= align_of::<T2>(), "T2 has stricter alignment than T");
179 unsafe {
180 &*core::ptr::slice_from_raw_parts_mut(
181 self.as_ptr() as _,
182 self.len() * size_of::<T>() / size_of::<T2>(),
183 )
184 }
185 }
186}