gpu/
cg.rs

1use crate::dim::lane_id;
2
/// Zero-sized, `Copy` marker type; presumably names the single-thread
/// cooperative-group scope (TODO confirm against its users).
#[derive(Clone, Copy)]
pub struct Thread;
5
/// Zero-sized, `Copy` marker for the thread block; `Block::thread_rank`
/// gives the linear index of the calling thread within it.
#[derive(Clone, Copy)]
pub struct Block;
8
/// Zero-sized, `Copy` marker type; presumably names the whole-grid
/// cooperative-group scope (TODO confirm against its users).
#[derive(Clone, Copy)]
pub struct Grid;
11
/// A software-defined tile of threads inside one hardware warp.
///
/// Similar to a thread block tile (cooperative groups), except that `SIZE`
/// must not exceed the warp size (e.g., 32 for NVIDIA GPUs). `SIZE` is the
/// number of threads per tile and `STRIDE` the lane distance between
/// consecutive members of a tile; both are validated at compile time by
/// `CHECKED_SIZE`.
///
/// If `SIZE = 8` and `STRIDE = 4`, the tiles are:
/// ```text
/// [0, 4, 8, 12, 16, 20, 24, 28]
/// [1, 5, 9, 13, 17, 21, 25, 29]
/// [2, 6, 10, 14, 18, 22, 26, 30]
/// [3, 7, 11, 15, 19, 23, 27, 31]
/// ```
///
/// If `SIZE = 8` and `STRIDE = 1`, the tiles are:
/// ```text
/// [0, 1, 2, 3, 4, 5, 6, 7]
/// [8, 9, 10, 11, 12, 13, 14, 15]
/// [16, 17, 18, 19, 20, 21, 22, 23]
/// [24, 25, 26, 27, 28, 29, 30, 31]
/// ```
#[derive(Copy, Clone)]
pub struct ThreadWarpTile<const SIZE: usize = 32, const STRIDE: usize = 1>;
26
/// Private marker implemented by every reduction operator type declared
/// through `impl_redux_kind!`.
trait ReduxKind {
    // Intentionally empty: implementing the trait is the only information.
}
28
29trait NvvmReduxSyncKind<T>: ReduxKind {
30    const KIND: &'static str;
31}
32
33trait SubGroupReduceKind<T>: ReduxKind {
34    const SUB_GROUP_KIND: &'static str;
35}
36
/// Declares the zero-sized reduction operator `$name`, marks it as a
/// `ReduxKind`, and implements `NvvmReduxSyncKind<type>` once per listed
/// `kind type` pair. `kind` is spliced into a `#nvvm<redux_kind kind>`
/// attribute string. A trailing comma after the last pair is accepted.
///
/// The generated type derives `Copy`/`Clone` like the other zero-sized
/// markers in this module, so an operator value can be reused after being
/// passed to `redux`.
macro_rules! impl_redux_kind {
    ($name:ident, $($kind:ident $type:ty),+ $(,)?) => {
        #[derive(Copy, Clone)]
        pub struct $name;

        impl ReduxKind for $name {}

        // `+` mirrors the matcher's repetition (and `impl_subgroup_kind!`).
        $(
            impl NvvmReduxSyncKind<$type> for $name {
                const KIND: &'static str =
                    concat!("#nvvm<redux_kind ", stringify!($kind), ">");
            }
        )+
    };
}
47
/// Implements `SubGroupReduceKind<type>` for the already-declared operator
/// `$op`, once per `mode type` pair; `mode` is spliced into a
/// `#gpu<all_reduce_op mode>` attribute string.
macro_rules! impl_subgroup_kind {
    ($op:ident, $($mode:ident $elem:ty),+) => {
        $(
            impl SubGroupReduceKind<$elem> for $op {
                const SUB_GROUP_KIND: &'static str =
                    concat!("#gpu<all_reduce_op ", stringify!($mode), ">");
            }
        )+
    };
}
55
56impl_redux_kind! {ReduxAdd, add i32, add u32}
57impl_subgroup_kind!(ReduxAdd, add i32, add u32);
58impl_redux_kind! {ReduxMax, max i32, umax u32}
59impl_subgroup_kind!(ReduxMax, maxsi i32, maxui u32);
60impl_redux_kind! {ReduxMin, min i32, umin u32}
61impl_subgroup_kind!(ReduxMin, minsi i32, minui u32);
62
63impl_redux_kind! {ReduxAnd, and u32}
64impl_redux_kind! {ReduxOr, or u32}
65impl_redux_kind! {ReduxXor, xor u32}
66
67#[expect(private_bounds)]
68pub trait WarpReduceOp<T, Op: ReduxKind> {
69    fn redux(&self, op: Op, value: T) -> T;
70}
71
/// Operations shared by cooperative-group handles.
pub trait CGOperations {
    /// Linear rank of the calling thread within this group.
    fn thread_rank(&self) -> u32;
}
75
76///
77///
78/// ```rust
79/// let warp = gpu::cg::ThreadWarpTile::<16>;
80/// let size = warp.size();
81/// ```
82///
83///  ```rust,compile_fail,E0080
84/// let warp = gpu::cg::ThreadWarpTile::<3>;
85/// let size = warp.size();
86/// ```
87impl<const SIZE: usize, const STRIDE: usize> ThreadWarpTile<SIZE, STRIDE> {
88    #[expect(clippy::cast_possible_truncation)]
89    pub const CHECKED_SIZE: u32 = {
90        assert!(
91            SIZE == 1 || SIZE == 2 || SIZE == 4 || SIZE == 8 || SIZE == 16 || SIZE == 32,
92            "SIZE must be <= 32 and be a power of 2"
93        );
94        assert!(STRIDE <= 32 / SIZE && STRIDE >= 1, "STRIDE must be >= 1 and <= 32 / SIZE");
95        SIZE as u32
96    };
97}
98
99/// Implement simple flexible warp with STRIDE = 1.
100impl<const SIZE: usize> ThreadWarpTile<SIZE, 1> {
101    pub const BASE_THREAD_MASK: u32 = { (1u32 << Self::CHECKED_SIZE) - 1 };
102    pub const LANE_MASK: u32 = Self::CHECKED_SIZE - 1;
103
104    // warp size == 32 => 0
105    // warp size == 1 => 5
106    pub const SHIFT_COUNT: u32 = { 5 - Self::CHECKED_SIZE.trailing_zeros() };
107
108    #[gpu_macros::device]
109    #[inline(always)]
110    pub const fn size(&self) -> u32 {
111        Self::CHECKED_SIZE as _
112    }
113
114    pub fn meta_group_size(&self) -> u32 {
115        Self::_meta_group_size()
116    }
117
118    pub(crate) fn _meta_group_size() -> u32 {
119        crate::dim::block_size() / Self::CHECKED_SIZE
120    }
121
122    #[gpu_macros::device]
123    #[inline(always)]
124    pub fn subgroup_id(&self) -> u32 {
125        Self::_subgroup_id()
126    }
127
128    #[gpu_macros::device]
129    #[inline(always)]
130    pub(crate) fn _subgroup_id() -> u32 {
131        Block.thread_rank() >> (5 - Self::SHIFT_COUNT)
132    }
133
134    #[gpu_macros::device]
135    #[inline(always)]
136    pub(crate) fn _thread_rank() -> u32 {
137        Block.thread_rank() & Self::LANE_MASK
138    }
139
140    /// E.g., when SIZE = 8,
141    /// lane_id -> mask
142    /// 0 -> 0xff
143    /// 1 -> 0xff
144    /// 8 -> 0xff00
145    /// 9 -> 0xff00
146    pub fn thread_mask(&self) -> u32 {
147        Self::BASE_THREAD_MASK << (lane_id() & !Self::LANE_MASK)
148    }
149
150    #[gpu_macros::device]
151    #[inline(always)]
152    fn reduce_with_shuffle(&self, value: f32, op: impl Fn(f32, f32) -> f32) -> f32 {
153        let mut offset = Self::CHECKED_SIZE >> 1;
154        let mut value = value;
155        while offset > 0 {
156            let (peer_val, _) = crate::shuffle!(xor, value, offset, Self::CHECKED_SIZE);
157            value = op(value, peer_val);
158            offset /= 2;
159        }
160        value
161    }
162
163    // Hardware-specific warp reduce.
164    #[gpu_macros::device]
165    #[inline(always)]
166    #[expect(private_bounds)]
167    pub fn nvcc_redux_sync<Op: NvvmReduxSyncKind<T>, T>(&self, _op: Op, value: T) -> T {
168        _redux_sync::<T>(value, self.thread_mask(), <Op as NvvmReduxSyncKind<T>>::KIND)
169    }
170    /*#[gpu_macros::device]
171    #[inline(always)]
172    pub fn run_on_lane_0<T>(self, slice: &mut [T], f: impl FnOnce(&mut T) + Clone + Send) {
173        if self.lane_id() == 0 {
174            // Build ref from the slice on the wrap
175            let threads_per_block = block_id::<X>()
176                * block_id::<Y>()
177                * block_id::<Z>();
178            // TODO: Although not exactly necessary, shall we enforce threads_per_block % SIZE == 0?
179            let offset = (threads_per_block / SIZE)
180                * (crate::block_id::<X>()
181                    + crate::grid_dim::<X>()
182                        * (crate::block_id::<Y>()
183                            + crate::grid_dim::<Y>()
184                                * crate::block_id::<Z>()))
185                + self.subgroup_id();
186
187            // SAFETY: The offset is unique per Warp not per GPU thread. Although
188            // multiple threads in the same warp may access the same memory
189            // location, the `run_on_lane_0` function ensures only lane 0 (a
190            // specific thread inside a warp) will execute the closure.
191            // Thus it is safe to use it here.
192            let local_val = unsafe { crate::subslice_mut(slice, offset, 1) };
193
194            // Call exec_on_thread_0
195            f(&mut local_val[0]);
196        }
197    }*/
198}
199
200/*impl<T, Op: ReduxKind> WarpReduceOp<T, Op> for ThreadWarpTile<32, 1>
201where
202    Op: SubGroupReduceKind<T>,
203    Self: FullWarp
204{
205    fn redux(&self, op: Op, value: T) -> T {
206        self.subgroup_reduce(op, value)
207    }
208}
209*/
210
211/// Ideally, we should use `SubGroupReduceKind` here, but MLIR support is limited.
212/// Let user use subgroup_reduce directly.
213impl<const SIZE: usize, T, Op: ReduxKind> WarpReduceOp<T, Op> for ThreadWarpTile<SIZE, 1>
214where
215    Op: NvvmReduxSyncKind<T>,
216{
217    fn redux(&self, op: Op, value: T) -> T {
218        self.nvcc_redux_sync(op, value)
219    }
220}
221
222/*
223impl<const SIZE: usize> WarpReduceOp<f32, ReduxMax> for ThreadWarpTile<SIZE> {
224    // LLVM 20.1.8 does not support redux.sync.max.f32
225    // This requires sm_100f and so reserved only for future arch even after Hopper.
226    fn redux(&self, _op: ReduxMax, value: f32) -> f32 {
227        let mut ret: f32;
228        unsafe {
229            core::arch::asm!(
230                "redux.sync.max.f32 {0:e}, {1:e}, {2:e};",
231                out(reg) ret,
232                in(reg) value,
233                in(reg) self.thread_mask()
234            );
235        }
236        ret
237    }
238}
239*/
240
241/// PTX does not support `redux.sync.add.f32` and so implement it via shuffle.
242impl !SubGroupReduceKind<f32> for ReduxAdd {}
243impl !NvvmReduxSyncKind<f32> for ReduxAdd {}
244
245impl<const SIZE: usize> WarpReduceOp<f32, ReduxAdd> for ThreadWarpTile<SIZE> {
246    fn redux(&self, _op: ReduxAdd, value: f32) -> f32 {
247        self.reduce_with_shuffle(value, |a, b| a + b)
248    }
249}
250
251impl !SubGroupReduceKind<f32> for ReduxMax {}
252impl !NvvmReduxSyncKind<f32> for ReduxMax {}
253
254impl<const SIZE: usize> WarpReduceOp<f32, ReduxMax> for ThreadWarpTile<SIZE> {
255    fn redux(&self, _op: ReduxMax, value: f32) -> f32 {
256        self.reduce_with_shuffle(value, |a, b| a.max(b))
257    }
258}
259
impl<const SIZE: usize, const STRIDE: usize> ThreadWarpTile<SIZE, STRIDE> {
    /// Reduce by hardware-defined warp.
    /// For now, it only supports `i32` or `u32` types.
    ///
    /// Intrinsic stub: the Rust body only panics via `unimplemented!()`.
    /// Calls are identified by the `gpu::subgroup_reduce` diagnostic item,
    /// presumably so the project's GPU backend can replace them with a
    /// `gpu.subgroup_reduce` op (TODO confirm). `_op` is a
    /// `#gpu<all_reduce_op ...>` attribute string as produced by
    /// `SubGroupReduceKind::SUB_GROUP_KIND`.
    ///
    /// NOTE(review): `SIZE` and `STRIDE` are not passed as arguments; confirm
    /// the backend reads them from the monomorphized `Self` type.
    #[rustc_diagnostic_item = "gpu::subgroup_reduce"]
    #[gpu_macros::device]
    #[inline(never)]
    pub fn _subgroup_reduce<T>(_value: T, _op: &'static str) -> T {
        unimplemented!()
    }

    /// Reduce by software-defined warp.
    ///
    /// Typed wrapper over `_subgroup_reduce`: `Op` selects the
    /// `#gpu<all_reduce_op ...>` attribute through `SubGroupReduceKind<T>`,
    /// so only operator/type pairs with such an impl compile. Takes the tile
    /// by value (it is `Copy`).
    #[gpu_macros::device]
    #[inline(always)]
    #[expect(private_bounds)]
    pub fn subgroup_reduce<Op, T>(self, _op: Op, value: T) -> T
    where
        Op: SubGroupReduceKind<T>,
    {
        Self::_subgroup_reduce::<T>(value, Op::SUB_GROUP_KIND)
    }
}
281
/// Intrinsic stub for NVVM `redux.sync`.
///
/// The Rust body only panics via `unimplemented!()`. Calls are identified by
/// the `nvvm::redux_sync` diagnostic item, presumably so the project's GPU
/// backend can lower them to the NVVM op (TODO confirm). `_mask` selects the
/// participating lanes (`nvcc_redux_sync` passes the tile's `thread_mask()`),
/// and `_op` is a `#nvvm<redux_kind ...>` attribute string from
/// `NvvmReduxSyncKind::KIND`.
#[rustc_diagnostic_item = "nvvm::redux_sync"]
#[gpu_macros::device]
#[inline(never)]
pub fn _redux_sync<T>(_value: T, _mask: u32, _op: &'static str) -> T {
    unimplemented!()
}
288
/// Intrinsic stub for a GPU shuffle; normally reached through `shuffle!`.
///
/// The Rust body only panics via `unimplemented!()`. Calls are identified by
/// the `gpu::shuffle` diagnostic item, presumably so the project's GPU backend
/// can emit an MLIR `gpu.shuffle` (TODO confirm). `_op` is a
/// `#gpu<shuffle_mode ...>` attribute string. `_width` is the shuffle width;
/// `reduce_with_shuffle` passes the tile size.
///
/// NOTE(review): the `bool` in the result is presumably the shuffle's
/// validity flag; confirm against the backend.
#[rustc_diagnostic_item = "gpu::shuffle"]
#[gpu_macros::device]
#[inline(never)]
pub fn _shuffle<T>(_value: T, _offset: u32, _width: u32, _op: &'static str) -> (T, bool) {
    unimplemented!()
}
295
/// Shuffle helper: `shuffle!(mode, value, offset, width)`.
///
/// Forwards to `$crate::cg::_shuffle` with the `#gpu<shuffle_mode ...>`
/// attribute for `mode`, which must be one of `xor`, `up`, `down`, or `idx`.
/// The expansion evaluates to whatever `_shuffle` returns.
#[macro_export]
macro_rules! shuffle {
    (xor, $v:expr, $off:expr, $w:expr) => {{
        $crate::cg::_shuffle($v, $off, $w, "#gpu<shuffle_mode xor>")
    }};
    (up, $v:expr, $off:expr, $w:expr) => {{
        $crate::cg::_shuffle($v, $off, $w, "#gpu<shuffle_mode up>")
    }};
    (down, $v:expr, $off:expr, $w:expr) => {{
        $crate::cg::_shuffle($v, $off, $w, "#gpu<shuffle_mode down>")
    }};
    (idx, $v:expr, $off:expr, $w:expr) => {{
        $crate::cg::_shuffle($v, $off, $w, "#gpu<shuffle_mode idx>")
    }};
}
304
305impl<const SIZE: usize> CGOperations for ThreadWarpTile<SIZE> {
306    #[gpu_macros::device]
307    #[inline(always)]
308    fn thread_rank(&self) -> u32 {
309        Self::_thread_rank() as _
310    }
311}
312
313impl Block {
314    #[gpu_macros::device]
315    #[inline(always)]
316    pub fn thread_rank(&self) -> u32 {
317        crate::dim::thread_id::<crate::dim::DimX>()
318            + crate::dim::block_dim::<crate::dim::DimX>()
319                * (crate::dim::thread_id::<crate::dim::DimY>()
320                    + crate::dim::block_dim::<crate::dim::DimY>()
321                        * crate::dim::thread_id::<crate::dim::DimZ>())
322    }
323}