1use core::marker::PhantomData;
2use core::ops::{Index, IndexMut};
3
4use num_traits::AsPrimitive;
5
6use crate::chunk_scope::{
7 ChainedMap, ChainedScope, ChunkScope, Grid, Grid2ThreadScope, TID_MAX_LEN, Thread,
8};
9use crate::{GpuGlobal, assert_ptr};
10
/// Index map used by [`GlobalGroupChunk`] for the chunking scope `CS`.
///
/// An implementor turns a caller-supplied index plus the current thread ids
/// into an index into the chunk's backing slice.
///
/// # Safety
///
/// The trait is `unsafe`, so implementors take on a contract that this file
/// does not spell out.
/// NOTE(review): the name suggests that distinct threads must never be mapped
/// to the same global index — confirm and document the exact requirement.
pub unsafe trait ScopeUniqueMap<CS: ChunkScope>: Clone {
    /// Index type accepted from callers; it is the key type of the
    /// `Index`/`IndexMut` impls on [`GlobalGroupChunk`].
    type IndexType: Copy + 'static;
    /// Index type produced by [`map`](Self::map); the indexing impls convert
    /// it to `usize` with `as_()` before touching the slice.
    type GlobalIndexType: AsPrimitive<usize>;
    /// Check on the map parameters themselves; the default accepts everything.
    ///
    /// `GlobalGroupChunk::new` aborts when this returns `false`, and every
    /// index operation ANDs it into the flag handed to `assert_ptr`.
    #[inline]
    #[gpu_codegen::device]
    fn precondition(&self) -> bool {
        true
    }

    /// Translates `idx` for the thread described by `thread_ids`.
    ///
    /// Returns `(flag, global_index)`. The indexing impls AND `flag` with
    /// [`precondition`](Self::precondition) and pass the result to
    /// `assert_ptr`.
    /// NOTE(review): presumably `flag == false` marks an index that must not
    /// be accessed — confirm.
    #[gpu_codegen::device]
    fn map(
        &self,
        idx: Self::IndexType,
        thread_ids: [u32; TID_MAX_LEN],
    ) -> (bool, Self::GlobalIndexType);
}
45
46pub(crate) trait ScopeUniqueMapProvidedMethods<CS: ChunkScope>: ScopeUniqueMap<CS> {
49 #[inline]
50 #[gpu_codegen::device]
51 fn local_to_global_index(&self, idx: Self::IndexType) -> (bool, Self::GlobalIndexType) {
52 self.map(idx, CS::thread_ids())
53 }
54}
55
// Blanket impl: every `ScopeUniqueMap<CS>` automatically gets the provided
// helper methods; the body is empty because the trait supplies defaults.
impl<T: ScopeUniqueMap<CS>, CS: ChunkScope> ScopeUniqueMapProvidedMethods<CS> for T {}
57
/// A mutable slice paired with the index map used to address it under the
/// chunking scope `CS`.
///
/// Elements are reached through the `Index`/`IndexMut` impls, which first run
/// the caller's index through `map_params`.
pub struct GlobalGroupChunk<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> {
    /// Backing slice; indexed with the `usize` form of the mapped index.
    data: &'a mut [T],
    /// Map instance that translates caller indices into slice indices.
    pub map_params: Map,
    /// Zero-sized marker that ties the value to `CS` without storing one.
    dummy: PhantomData<CS>,
}
69
/// [`GlobalGroupChunk`] with the scope fixed to `Grid2ThreadScope`.
pub type GlobalThreadChunk<'a, T, Map> = GlobalGroupChunk<'a, T, Grid2ThreadScope, Map>;
71
72impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> GlobalGroupChunk<'a, T, CS, Map>
74where
75 CS: ChunkScope<FromScope = Grid>,
76{
77 #[inline]
78 #[rustc_diagnostic_item = "gpu::chunk_mut"]
79 #[gpu_codegen::device]
80 #[gpu_codegen::sync_data(0, 1)]
81 pub fn new(global: GpuGlobal<'a, [T]>, map_params: Map) -> Self {
84 if !map_params.precondition() {
85 core::intrinsics::abort();
86 }
87 Self { data: global.data, map_params, dummy: PhantomData }
88 }
89
90 #[cfg(not(feature = "codegen_tests"))]
97 pub fn new_from_host<'b: 'a>(
98 slice: &'a mut cuda_bindings::TensorViewMut<'b, [T]>,
99 map_params: Map,
100 ) -> Self {
101 unsafe {
102 Self {
103 data: &mut *(slice.as_flat_devptr() as *mut [T]),
104 map_params,
105 dummy: PhantomData,
106 }
107 }
108 }
109}
110
111impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> GlobalGroupChunk<'a, T, CS, Map> {
113 #[gpu_codegen::device]
116 #[gpu_codegen::sync_data(2)]
117 pub fn chunk_to_scope<CS2: ChunkScope, Map2: ScopeUniqueMap<CS2>>(
118 self,
119 _cs: CS2,
120 map: Map2,
121 ) -> GlobalGroupChunk<'a, T, ChainedScope<CS, CS2>, ChainedMap<CS, CS2, Map, Map2>>
122 where
123 Map: ScopeUniqueMap<CS>,
124 CS: ChunkScope<ToScope = CS2::FromScope>,
125 Map2::GlobalIndexType: AsPrimitive<Map::IndexType>,
126 {
127 GlobalGroupChunk {
128 data: self.data,
129 map_params: ChainedMap::new(self.map_params, map),
130 dummy: PhantomData,
131 }
132 }
133
134 #[gpu_codegen::device]
135 #[inline]
136 pub fn local2global(
137 &self,
138 idx: <Map as ScopeUniqueMap<CS>>::IndexType,
139 ) -> Map::GlobalIndexType {
140 self.map_params.local_to_global_index(idx).1
141 }
142}
143
144impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> Index<Map::IndexType>
146 for GlobalGroupChunk<'a, T, CS, Map>
147{
148 type Output = T;
149
150 #[inline(always)]
151 #[gpu_codegen::device]
152 fn index(&self, idx: Map::IndexType) -> &T {
153 let (idx_precondition, idx) = self.map_params.local_to_global_index(idx);
154 let idx = idx.as_();
155 assert_ptr(self.map_params.precondition() & idx_precondition, &self.data[idx])
156 }
157}
158
159impl<'a, T, CS: ChunkScope, Map: ScopeUniqueMap<CS>> IndexMut<Map::IndexType>
162 for GlobalGroupChunk<'a, T, CS, Map>
163where
164 CS: ChunkScope<ToScope = Thread>,
165{
166 #[inline(always)]
167 #[gpu_codegen::device]
168 fn index_mut(&mut self, idx: Map::IndexType) -> &mut T {
169 let (idx_precondition, idx) = self.map_params.local_to_global_index(idx);
170 let idx = idx.as_();
171 assert_ptr(self.map_params.precondition() & idx_precondition, &mut self.data[idx])
172 }
173}
174
175#[gpu_codegen::device]
181#[gpu_codegen::sync_data(0, 1, 2)]
182#[inline(always)]
183pub fn chunk_mut<'a, T, Map: ScopeUniqueMap<Grid2ThreadScope>>(
184 input: GpuGlobal<'a, [T]>,
185 map: Map,
186) -> GlobalThreadChunk<'a, T, Map> {
187 GlobalThreadChunk::new(input, map)
188}