1use crate::dim::lane_id;
2
/// Zero-sized handle for the single-thread cooperative group
/// (presumably — TODO confirm intended use).
#[derive(Clone, Copy)]
pub struct Thread;

/// Zero-sized handle for the calling thread's block; `Block::thread_rank`
/// gives the linearized index of the calling thread within it.
#[derive(Clone, Copy)]
pub struct Block;

/// Zero-sized handle for the whole launch grid (presumably — TODO confirm).
#[derive(Clone, Copy)]
pub struct Grid;
11
/// Zero-sized handle for a statically sized tile of `SIZE` threads within a warp.
///
/// `SIZE` must be a power of two no larger than 32, and `STRIDE` must lie in
/// `1..=32 / SIZE`. Both conditions are checked whenever `CHECKED_SIZE` is
/// evaluated. The defaults (`SIZE = 32`, `STRIDE = 1`) cover all 32 lanes.
#[derive(Clone, Copy)]
pub struct ThreadWarpTile<const SIZE: usize = 32, const STRIDE: usize = 1>;
26
/// Marker trait shared by every reduction-operator type (`ReduxAdd`, `ReduxMax`, ...).
///
/// The trait is private, so code outside this module cannot name it or
/// implement it.
trait ReduxKind {
}
28
29trait NvvmReduxSyncKind<T>: ReduxKind {
30 const KIND: &'static str;
31}
32
33trait SubGroupReduceKind<T>: ReduxKind {
34 const SUB_GROUP_KIND: &'static str;
35}
36
/// Declares a public reduction-operator marker type `$name` and implements
/// `ReduxKind` for it. For every `$kind $type` pair it also implements
/// `NvvmReduxSyncKind<$type>`, with `KIND` set to `"#nvvm<redux_kind $kind>"`.
///
/// Example: `impl_redux_kind! {ReduxMax, max i32, umax u32}`.
macro_rules! impl_redux_kind {
    ($name: ident, $($kind: ident $type: ty),+) => {
        /// Reduction-operator marker generated by `impl_redux_kind!`.
        pub struct $name;
        // Repeat with `+` to match the matcher's `+`. Mixing Kleene operators
        // between matcher and transcriber trips rustc's `meta_variable_misuse` lint.
        $(impl NvvmReduxSyncKind<$type> for $name {
            const KIND: &'static str = concat!("#nvvm<redux_kind ", stringify!($kind), ">");
        })+
        impl ReduxKind for $name {}
    };
}
47
/// Implements `SubGroupReduceKind<$elem>` for an existing operator type `$name`
/// for every `$op $elem` pair. `SUB_GROUP_KIND` is set to
/// `"#gpu<all_reduce_op $op>"`.
///
/// Example: `impl_subgroup_kind!(ReduxMin, minsi i32, minui u32);`.
macro_rules! impl_subgroup_kind {
    ($name:ident, $($op:ident $elem:ty),+) => {
        $(
            impl SubGroupReduceKind<$elem> for $name {
                const SUB_GROUP_KIND: &'static str =
                    concat!("#gpu<all_reduce_op ", stringify!($op), ">");
            }
        )+
    };
}
55
56impl_redux_kind! {ReduxAdd, add i32, add u32}
57impl_subgroup_kind!(ReduxAdd, add i32, add u32);
58impl_redux_kind! {ReduxMax, max i32, umax u32}
59impl_subgroup_kind!(ReduxMax, maxsi i32, maxui u32);
60impl_redux_kind! {ReduxMin, min i32, umin u32}
61impl_subgroup_kind!(ReduxMin, minsi i32, minui u32);
62
63impl_redux_kind! {ReduxAnd, and u32}
64impl_redux_kind! {ReduxOr, or u32}
65impl_redux_kind! {ReduxXor, xor u32}
66
67#[expect(private_bounds)]
68pub trait WarpReduceOp<T, Op: ReduxKind> {
69 fn redux(&self, op: Op, value: T) -> T;
70}
71
/// Operations shared by cooperative-group handles.
pub trait CGOperations {
    /// Index of the calling thread within the group.
    fn thread_rank(&self) -> u32;
}
75
76impl<const SIZE: usize, const STRIDE: usize> ThreadWarpTile<SIZE, STRIDE> {
88 #[expect(clippy::cast_possible_truncation)]
89 pub const CHECKED_SIZE: u32 = {
90 assert!(
91 SIZE == 1 || SIZE == 2 || SIZE == 4 || SIZE == 8 || SIZE == 16 || SIZE == 32,
92 "SIZE must be <= 32 and be a power of 2"
93 );
94 assert!(STRIDE <= 32 / SIZE && STRIDE >= 1, "STRIDE must be >= 1 and <= 32 / SIZE");
95 SIZE as u32
96 };
97}
98
99impl<const SIZE: usize> ThreadWarpTile<SIZE, 1> {
101 pub const BASE_THREAD_MASK: u32 = { (1u32 << Self::CHECKED_SIZE) - 1 };
102 pub const LANE_MASK: u32 = Self::CHECKED_SIZE - 1;
103
104 pub const SHIFT_COUNT: u32 = { 5 - Self::CHECKED_SIZE.trailing_zeros() };
107
108 #[gpu_macros::device]
109 #[inline(always)]
110 pub const fn size(&self) -> u32 {
111 Self::CHECKED_SIZE as _
112 }
113
114 pub fn meta_group_size(&self) -> u32 {
115 Self::_meta_group_size()
116 }
117
118 pub(crate) fn _meta_group_size() -> u32 {
119 crate::dim::block_size() / Self::CHECKED_SIZE
120 }
121
122 #[gpu_macros::device]
123 #[inline(always)]
124 pub fn subgroup_id(&self) -> u32 {
125 Self::_subgroup_id()
126 }
127
128 #[gpu_macros::device]
129 #[inline(always)]
130 pub(crate) fn _subgroup_id() -> u32 {
131 Block.thread_rank() >> (5 - Self::SHIFT_COUNT)
132 }
133
134 #[gpu_macros::device]
135 #[inline(always)]
136 pub(crate) fn _thread_rank() -> u32 {
137 Block.thread_rank() & Self::LANE_MASK
138 }
139
140 pub fn thread_mask(&self) -> u32 {
147 Self::BASE_THREAD_MASK << (lane_id() & !Self::LANE_MASK)
148 }
149
150 #[gpu_macros::device]
151 #[inline(always)]
152 fn reduce_with_shuffle(&self, value: f32, op: impl Fn(f32, f32) -> f32) -> f32 {
153 let mut offset = Self::CHECKED_SIZE >> 1;
154 let mut value = value;
155 while offset > 0 {
156 let (peer_val, _) = crate::shuffle!(xor, value, offset, Self::CHECKED_SIZE);
157 value = op(value, peer_val);
158 offset /= 2;
159 }
160 value
161 }
162
163 #[gpu_macros::device]
165 #[inline(always)]
166 #[expect(private_bounds)]
167 pub fn nvcc_redux_sync<Op: NvvmReduxSyncKind<T>, T>(&self, _op: Op, value: T) -> T {
168 _redux_sync::<T>(value, self.thread_mask(), <Op as NvvmReduxSyncKind<T>>::KIND)
169 }
170 }
199
200impl<const SIZE: usize, T, Op: ReduxKind> WarpReduceOp<T, Op> for ThreadWarpTile<SIZE, 1>
214where
215 Op: NvvmReduxSyncKind<T>,
216{
217 fn redux(&self, op: Op, value: T) -> T {
218 self.nvcc_redux_sync(op, value)
219 }
220}
221
222impl !SubGroupReduceKind<f32> for ReduxAdd {}
243impl !NvvmReduxSyncKind<f32> for ReduxAdd {}
244
245impl<const SIZE: usize> WarpReduceOp<f32, ReduxAdd> for ThreadWarpTile<SIZE> {
246 fn redux(&self, _op: ReduxAdd, value: f32) -> f32 {
247 self.reduce_with_shuffle(value, |a, b| a + b)
248 }
249}
250
251impl !SubGroupReduceKind<f32> for ReduxMax {}
252impl !NvvmReduxSyncKind<f32> for ReduxMax {}
253
254impl<const SIZE: usize> WarpReduceOp<f32, ReduxMax> for ThreadWarpTile<SIZE> {
255 fn redux(&self, _op: ReduxMax, value: f32) -> f32 {
256 self.reduce_with_shuffle(value, |a, b| a.max(b))
257 }
258}
259
/// Subgroup reductions. Unlike the shuffle- and `redux.sync`-based helpers,
/// which exist only for `STRIDE == 1`, these are defined for every
/// `SIZE`/`STRIDE` combination.
impl<const SIZE: usize, const STRIDE: usize> ThreadWarpTile<SIZE, STRIDE> {
    /// Placeholder for a GPU-dialect subgroup reduction.
    ///
    /// The body is `unimplemented!()`, so executing it as ordinary Rust panics.
    /// Presumably the GPU compiler recognizes the `gpu::subgroup_reduce`
    /// diagnostic item and substitutes a real lowering, and `#[inline(never)]`
    /// is there so the call survives for that substitution — TODO confirm
    /// against the backend.
    ///
    /// `_op` is a `#gpu<all_reduce_op ...>` attribute string (see `impl_subgroup_kind!`).
    #[rustc_diagnostic_item = "gpu::subgroup_reduce"]
    #[gpu_macros::device]
    #[inline(never)]
    pub fn _subgroup_reduce<T>(_value: T, _op: &'static str) -> T {
        unimplemented!()
    }

    /// Reduces `value` across the subgroup with the operator `Op`.
    ///
    /// The `_op` argument only selects the operator at the type level. What is
    /// forwarded to `_subgroup_reduce` is `Op::SUB_GROUP_KIND`.
    #[gpu_macros::device]
    #[inline(always)]
    #[expect(private_bounds)]
    pub fn subgroup_reduce<Op, T>(self, _op: Op, value: T) -> T
    where
        Op: SubGroupReduceKind<T>,
    {
        Self::_subgroup_reduce::<T>(value, Op::SUB_GROUP_KIND)
    }
}
281
/// Placeholder for an NVVM `redux.sync` reduction.
///
/// The body is `unimplemented!()`, so calling it as ordinary Rust panics.
/// Presumably the GPU compiler recognizes the `nvvm::redux_sync` diagnostic
/// item and replaces the call, and `#[inline(never)]` keeps the call site
/// intact for that — TODO confirm against the backend.
///
/// * `_value` – the per-thread value to reduce.
/// * `_mask` – lane mask; `ThreadWarpTile::nvcc_redux_sync` passes `thread_mask()`.
/// * `_op` – a `#nvvm<redux_kind ...>` attribute string (see `impl_redux_kind!`).
#[rustc_diagnostic_item = "nvvm::redux_sync"]
#[gpu_macros::device]
#[inline(never)]
pub fn _redux_sync<T>(_value: T, _mask: u32, _op: &'static str) -> T {
    unimplemented!()
}
288
/// Placeholder for a GPU-dialect shuffle; normally reached through `shuffle!`.
///
/// The body is `unimplemented!()`, so calling it as ordinary Rust panics.
/// Presumably the GPU compiler recognizes the `gpu::shuffle` diagnostic item
/// and replaces the call — TODO confirm against the backend.
///
/// * `_offset` / `_width` – shuffle offset and segment width
///   (`reduce_with_shuffle` passes the tile `SIZE` as the width).
/// * `_op` – a `#gpu<shuffle_mode ...>` attribute string.
///
/// Returns the shuffled value plus a `bool`. The `bool` is presumably a validity
/// flag, as in MLIR `gpu.shuffle` — TODO confirm. `reduce_with_shuffle`
/// ignores it.
#[rustc_diagnostic_item = "gpu::shuffle"]
#[gpu_macros::device]
#[inline(never)]
pub fn _shuffle<T>(_value: T, _offset: u32, _width: u32, _op: &'static str) -> (T, bool) {
    unimplemented!()
}
295
/// Shuffle helper: `shuffle!(mode, value, offset, width)`.
///
/// `mode` must be one of `xor`, `up`, `down`, `idx`. It selects the
/// `#gpu<shuffle_mode ...>` string forwarded to `cg::_shuffle`, and the macro
/// evaluates to whatever `_shuffle` returns.
#[macro_export]
macro_rules! shuffle {
    (xor, $value:expr, $offset:expr, $width:expr) => {{
        $crate::cg::_shuffle($value, $offset, $width, "#gpu<shuffle_mode xor>")
    }};
    (up, $value:expr, $offset:expr, $width:expr) => {{
        $crate::cg::_shuffle($value, $offset, $width, "#gpu<shuffle_mode up>")
    }};
    (down, $value:expr, $offset:expr, $width:expr) => {{
        $crate::cg::_shuffle($value, $offset, $width, "#gpu<shuffle_mode down>")
    }};
    (idx, $value:expr, $offset:expr, $width:expr) => {{
        $crate::cg::_shuffle($value, $offset, $width, "#gpu<shuffle_mode idx>")
    }};
}
304
305impl<const SIZE: usize> CGOperations for ThreadWarpTile<SIZE> {
306 #[gpu_macros::device]
307 #[inline(always)]
308 fn thread_rank(&self) -> u32 {
309 Self::_thread_rank() as _
310 }
311}
312
313impl Block {
314 #[gpu_macros::device]
315 #[inline(always)]
316 pub fn thread_rank(&self) -> u32 {
317 crate::dim::thread_id::<crate::dim::DimX>()
318 + crate::dim::block_dim::<crate::dim::DimX>()
319 * (crate::dim::thread_id::<crate::dim::DimY>()
320 + crate::dim::block_dim::<crate::dim::DimY>()
321 * crate::dim::thread_id::<crate::dim::DimZ>())
322 }
323}