//! gpu/chunk.rs — thread-uniquely mapped chunk views over GPU global memory.

use core::marker::PhantomData;
use core::ops::{Index, IndexMut};

use num_traits::AsPrimitive;

use crate::chunk_scope::{
    ChainedMap, ChainedScope, ChunkScope, Grid, Grid2ThreadScope, TID_MAX_LEN, Thread,
};
use crate::{GpuGlobal, assert_ptr};
10
/// Thread-unique mapping trait.
///
/// This trait guarantees that each thread produces a unique index mapping,
/// so no two distinct threads map to the same global index.
///
/// # Type Parameters
/// - `CS`: The memory space: GlobalMemScope or SharedMemScope
///
/// # Safety
/// Implementors must ensure:
/// ```text
/// forall |idx1, idx2, thread_ids1, thread_ids2|
///     thread_ids1 != thread_ids2 ==>
///         map(idx1, thread_ids1) != map(idx2, thread_ids2)
/// ```
/// i.e. the sets of global indices reachable from two distinct threads are
/// disjoint, which is what makes unsynchronized mutable access sound.
pub unsafe trait ScopeUniqueMap<CS: ChunkScope>: Clone {
    /// Local (per-thread) index type accepted by [`ScopeUniqueMap::map`].
    type IndexType: Copy + 'static;
    /// Global index type produced by the mapping; convertible to `usize`
    /// so it can be used for slice indexing.
    type GlobalIndexType: AsPrimitive<usize>;

    /// Validity check on the mapping parameters themselves, evaluated once
    /// when the chunk is constructed. Defaults to always-valid.
    #[inline]
    #[gpu_codegen::device]
    fn precondition(&self) -> bool {
        true
    }

    /// Returns the extra precondition of the indexing operation and the global
    /// index. Without providing an extra precondition, indexing will always
    /// check the OOB error with the global idx.
    #[gpu_codegen::device]
    fn map(
        &self,
        idx: Self::IndexType,
        thread_ids: [u32; TID_MAX_LEN],
    ) -> (bool, Self::GlobalIndexType);
}
45
/// Provides `local_to_global_index` for chunking.
/// This is a private trait that should not be used outside this crate.
pub(crate) trait ScopeUniqueMapProvidedMethods<CS: ChunkScope>: ScopeUniqueMap<CS> {
    /// Maps a local index to `(extra_precondition, global_index)` using the
    /// calling thread's own ids (`CS::thread_ids()`), so callers never pass
    /// thread ids explicitly.
    #[inline]
    #[gpu_codegen::device]
    fn local_to_global_index(&self, idx: Self::IndexType) -> (bool, Self::GlobalIndexType) {
        self.map(idx, CS::thread_ids())
    }
}

// Blanket impl: every `ScopeUniqueMap` gets the provided method for free.
impl<T: ScopeUniqueMap<CS>, CS: ChunkScope> ScopeUniqueMapProvidedMethods<CS> for T {}
57
/// Represents a chunk of global memory that is uniquely mapped to each thread group.
/// It supports both continuous and non-continuous mapping strategies.
/// - `T`: element type
/// - `CS`: chunk scope describing which execution scope the mapping relates
/// - `Map`: mapping strategy type (carries the mapping parameters)
/// - `'a`: lifetime of the underlying slice
pub struct GlobalGroupChunk<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> {
    // Must be private: all access goes through the precondition-checked
    // Index/IndexMut impls below.
    data: &'a mut [T],
    pub map_params: Map,
    // Zero-sized marker tying the chunk to its scope type.
    dummy: PhantomData<CS>,
}

/// A chunk whose scope maps the whole grid down to individual threads,
/// i.e. each thread owns a unique sub-chunk.
pub type GlobalThreadChunk<'a, T, Map> = GlobalGroupChunk<'a, T, Grid2ThreadScope, Map>;
71
/// Creating a global chunk from `GpuGlobal`.
/// Only available for scopes whose mapping starts from the whole `Grid`.
impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> GlobalGroupChunk<'a, T, CS, Map>
where
    CS: ChunkScope<FromScope = Grid>,
{
    #[inline]
    #[rustc_diagnostic_item = "gpu::chunk_mut"]
    #[gpu_codegen::device]
    #[gpu_codegen::sync_data(0, 1)]
    /// Wraps `global` with the mapping strategy `map_params`; aborts if the
    /// mapping's precondition fails, so an invalid chunk is never observable.
    ///
    /// TODO: We will prevent rechunking of global mem in the mir_analysis.
    /// For now, we just leave it to the user.
    pub fn new(global: GpuGlobal<'a, [T]>, map_params: Map) -> Self {
        if !map_params.precondition() {
            core::intrinsics::abort();
        }
        Self { data: global.data, map_params, dummy: PhantomData }
    }

    /// In some cases, passing GlobalThreadChunk from host to device is more
    /// convenient. Due to unknown optimization-related factors, passing
    /// GlobalThreadChunk to kernel function may make the kernel faster (see
    /// examples/matmul). In addition, directly passing GlobalThreadChunk from
    /// host to device naturally avoids the problem of rechunking the global
    /// mem.
    ///
    /// NOTE(review): unlike `new`, this constructor does not check
    /// `map_params.precondition()` — presumably because it runs on the host;
    /// confirm that the precondition is still enforced before device access.
    #[cfg(not(feature = "codegen_tests"))]
    pub fn new_from_host<'b: 'a>(
        slice: &'a mut cuda_bindings::TensorViewMut<'b, [T]>,
        map_params: Map,
    ) -> Self {
        // SAFETY: assumes `as_flat_devptr()` yields a pointer valid as a
        // `*mut [T]` for lifetime 'a, with exclusive access guaranteed by the
        // `&'a mut` borrow of the TensorViewMut — TODO confirm against
        // cuda_bindings' contract.
        unsafe {
            Self {
                data: &mut *(slice.as_flat_devptr() as *mut [T]),
                map_params,
                dummy: PhantomData,
            }
        }
    }
}
110
111/// Creating global chunk from GpuGlobal.
112impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> GlobalGroupChunk<'a, T, CS, Map> {
113    /// Convert GlobalGroupChunk to another GlobalGroupChunk with different ChunkScope and Map.
114    /// See `ChunkScope` for more details about chunk scope.
115    #[gpu_codegen::device]
116    #[gpu_codegen::sync_data(2)]
117    pub fn chunk_to_scope<CS2: ChunkScope, Map2: ScopeUniqueMap<CS2>>(
118        self,
119        _cs: CS2,
120        map: Map2,
121    ) -> GlobalGroupChunk<'a, T, ChainedScope<CS, CS2>, ChainedMap<CS, CS2, Map, Map2>>
122    where
123        Map: ScopeUniqueMap<CS>,
124        CS: ChunkScope<ToScope = CS2::FromScope>,
125        Map2::GlobalIndexType: AsPrimitive<Map::IndexType>,
126    {
127        GlobalGroupChunk {
128            data: self.data,
129            map_params: ChainedMap::new(self.map_params, map),
130            dummy: PhantomData,
131        }
132    }
133
134    #[gpu_codegen::device]
135    #[inline]
136    pub fn local2global(
137        &self,
138        idx: <Map as ScopeUniqueMap<CS>>::IndexType,
139    ) -> Map::GlobalIndexType {
140        self.map_params.local_to_global_index(idx).1
141    }
142}
143
// Read-only access is always allowed: shared reads cannot race even when
// several threads map to the same element.
impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> Index<Map::IndexType>
    for GlobalGroupChunk<'a, T, CS, Map>
{
    type Output = T;

    /// Checked read: translates the local `idx` to a global index, then
    /// asserts both the mapping precondition and the per-index extra
    /// precondition before handing out the reference.
    #[inline(always)]
    #[gpu_codegen::device]
    fn index(&self, idx: Map::IndexType) -> &T {
        let (idx_precondition, idx) = self.map_params.local_to_global_index(idx);
        let idx = idx.as_();
        // Non-short-circuiting `&` — presumably deliberate to keep the check
        // branch-free in device code; TODO confirm with codegen.
        assert_ptr(self.map_params.precondition() & idx_precondition, &self.data[idx])
    }
}
158
// Mutable access is only allowed when ToScope is Thread,
// indicating that each thread has a unique chunk
// (the `ScopeUniqueMap` safety contract then rules out aliased writes).
impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> IndexMut<Map::IndexType>
    for GlobalGroupChunk<'a, T, CS, Map>
where
    CS: ChunkScope<ToScope = Thread>,
{
    /// Checked write: same precondition checks as `index`, but yields a
    /// mutable reference.
    #[inline(always)]
    #[gpu_codegen::device]
    fn index_mut(&mut self, idx: Map::IndexType) -> &mut T {
        let (idx_precondition, idx) = self.map_params.local_to_global_index(idx);
        let idx = idx.as_();
        // Non-short-circuiting `&` mirrors the read path — see `index`.
        assert_ptr(self.map_params.precondition() & idx_precondition, &mut self.data[idx])
    }
}
174
/// Chunks a global memory slice into per-thread unique chunks.
/// Supports mapping strategies for global memory.
/// For example,
/// - `chunk_mut(data, MapLinear::new(width))`
/// - `chunk_mut(data, Map2D::new(x_width))`
///
/// Thin wrapper over `GlobalThreadChunk::new`, so it aborts when
/// `map.precondition()` fails.
#[gpu_codegen::device]
#[gpu_codegen::sync_data(0, 1, 2)]
#[inline(always)]
pub fn chunk_mut<'a, T, Map: ScopeUniqueMap<Grid2ThreadScope>>(
    input: GpuGlobal<'a, [T]>,
    map: Map,
) -> GlobalThreadChunk<'a, T, Map> {
    GlobalThreadChunk::new(input, map)
}