1use core::marker::PhantomData;
2
3use num_traits::AsPrimitive;
4
5pub use crate::cg::{Block, Grid, Thread, ThreadWarpTile};
6use crate::chunk::ScopeUniqueMap;
7use crate::dim::{
8 DimType, DimTypeID, DimX, DimY, DimZ, block_dim, block_id, block_size, dim, num_blocks,
9 thread_id,
10};
11use crate::grid_dim;
12
/// Private marker trait used as a supertrait of [`ChunkScope`].
///
/// Because this trait is not `pub`, code that cannot name it cannot provide
/// new [`ChunkScope`] implementations (sealed-trait pattern).
trait PrivateTraitGuard {}
14
/// Private marker trait for the scope types this module can chunk between.
///
/// It is implemented in this file for `Thread`, `ThreadWarpTile<SIZE>`, `Block`
/// and `Grid`, and bounds the scope parameters of [`BuildChunkScope`] and the
/// `FromScope` / `ToScope` associated types of [`ChunkScope`].
trait SyncScope {}
16
/// Number of slots in the `thread_ids` arrays used throughout [`ChunkScope`].
///
/// `Grid2ThreadScope::thread_ids` fills all six: slots `0..3` hold the x/y/z
/// `thread_id` values and slots `3..6` hold the x/y/z `block_id` values.
pub const TID_MAX_LEN: usize = 6;
22
// Each scope type re-exported from `crate::cg` at the top of this file is
// marked as a `SyncScope`, so it can appear as a `FromScope` / `ToScope`.
impl SyncScope for Thread {}

impl<const SIZE: usize> SyncScope for ThreadWarpTile<SIZE> {}
impl SyncScope for Block {}
impl SyncScope for Grid {}
28
29#[expect(private_bounds)]
30pub trait BuildChunkScope<S2: SyncScope>: SyncScope {
31 type CS: ChunkScope<FromScope = Self, ToScope = S2>;
32 #[gpu_codegen::device]
33 #[gpu_codegen::ret_sync_data(1000)]
34 fn build_chunk_scope(&self, to: S2) -> Self::CS;
35}
36
37impl<const SIZE: usize> BuildChunkScope<ThreadWarpTile<SIZE>> for Block {
39 type CS = Block2WarpScope<SIZE>;
40 #[inline]
41 #[gpu_codegen::device]
42 fn build_chunk_scope(&self, _to: ThreadWarpTile<SIZE>) -> Block2WarpScope<SIZE> {
43 Block2WarpScope
44 }
45}
46
47impl BuildChunkScope<Thread> for Block {
49 type CS = Block2ThreadScope;
50 #[inline]
51 #[gpu_codegen::device]
52 fn build_chunk_scope(&self, _to: Thread) -> Block2ThreadScope {
53 Block2ThreadScope
54 }
55}
56
57impl BuildChunkScope<Block> for Grid {
59 type CS = Grid2BlockScope;
60 #[inline]
61 #[gpu_codegen::device]
62 fn build_chunk_scope(&self, _to: Block) -> Grid2BlockScope {
63 Grid2BlockScope
64 }
65}
66
67impl<const SIZE: usize> BuildChunkScope<ThreadWarpTile<SIZE>> for Grid {
69 type CS = Grid2WarpScope<SIZE>;
70 #[inline]
71 #[gpu_codegen::device]
72 fn build_chunk_scope(&self, _to: ThreadWarpTile<SIZE>) -> Grid2WarpScope<SIZE> {
73 Grid2WarpScope
74 }
75}
76
77impl BuildChunkScope<Thread> for Grid {
79 type CS = Grid2ThreadScope;
80 #[inline]
81 #[gpu_codegen::device]
82 fn build_chunk_scope(&self, _to: Thread) -> Grid2ThreadScope {
83 Grid2ThreadScope
84 }
85}
86
87impl<const SIZE: usize> BuildChunkScope<Thread> for ThreadWarpTile<SIZE> {
89 type CS = Warp2ThreadScope<SIZE>;
90 #[inline]
91 #[gpu_codegen::device]
92 fn build_chunk_scope(&self, _to: Thread) -> Warp2ThreadScope<SIZE> {
93 Warp2ThreadScope
94 }
95}
96
97#[inline]
98#[gpu_codegen::device]
99#[expect(private_bounds)]
100#[gpu_codegen::ret_sync_data(1000)]
101pub fn build_chunk_scope<S1, S2>(from: S1, to: S2) -> <S1 as BuildChunkScope<S2>>::CS
102where
103 S2: SyncScope,
104 S1: BuildChunkScope<S2> + SyncScope,
105{
106 from.build_chunk_scope(to)
107}
108
/// Describes how a coarser scope (`FromScope`) is divided into units of a
/// finer scope (`ToScope`): the number of units per dimension (`global_dim`)
/// and the index of the calling thread's unit (`global_id`).
///
/// The `PrivateTraitGuard` supertrait is private to this module, which keeps
/// the set of implementations closed.
///
/// NOTE(review): the exact effect of the `gpu_codegen::device` and
/// `gpu_codegen::ret_sync_data(1000)` attributes is defined by the
/// `gpu_codegen` crate and is not visible in this file.
#[expect(private_bounds)]
pub trait ChunkScope: PrivateTraitGuard + Clone {
    /// The scope being divided.
    type FromScope: SyncScope;
    /// The scope of each resulting unit.
    type ToScope: SyncScope;

    /// Gathers the ids this scope needs into a fixed `TID_MAX_LEN`-slot array.
    /// Slots a scope does not use are set to `0` by the implementations in
    /// this file.
    fn thread_ids() -> [u32; TID_MAX_LEN];

    /// Number of `ToScope` units along dimension `D`.
    #[gpu_codegen::ret_sync_data(1000)]
    fn global_dim<D: DimType>() -> u32;
    /// Index along dimension `D`, computed from a `thread_ids` array laid out
    /// as produced by [`ChunkScope::thread_ids`].
    fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32;

    /// Shorthand for `global_id::<DimX>`.
    #[inline]
    #[gpu_codegen::device]
    fn global_id_x(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
        Self::global_id::<DimX>(thread_ids)
    }

    /// Shorthand for `global_id::<DimY>`.
    #[inline]
    #[gpu_codegen::device]
    fn global_id_y(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
        Self::global_id::<DimY>(thread_ids)
    }

    /// Shorthand for `global_id::<DimZ>`.
    #[inline]
    #[gpu_codegen::device]
    fn global_id_z(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
        Self::global_id::<DimZ>(thread_ids)
    }

    /// Shorthand for `global_dim::<DimX>`.
    #[inline]
    #[gpu_codegen::device]
    #[gpu_codegen::ret_sync_data(1000)]
    fn global_dim_x() -> u32 {
        Self::global_dim::<DimX>()
    }

    /// Shorthand for `global_dim::<DimY>`.
    #[inline]
    #[gpu_codegen::ret_sync_data(1000)]
    #[gpu_codegen::device]
    fn global_dim_y() -> u32 {
        Self::global_dim::<DimY>()
    }

    /// Shorthand for `global_dim::<DimZ>`.
    #[inline]
    #[gpu_codegen::ret_sync_data(1000)]
    #[gpu_codegen::device]
    fn global_dim_z() -> u32 {
        Self::global_dim::<DimZ>()
    }
}
256
257#[derive(Copy, Clone)]
258pub struct Grid2ThreadScope;
259impl PrivateTraitGuard for Grid2ThreadScope {}
260impl ChunkScope for Grid2ThreadScope {
261 type FromScope = Grid;
262 type ToScope = Thread;
263
264 #[inline]
265 #[gpu_codegen::device]
266 fn thread_ids() -> [u32; TID_MAX_LEN] {
267 [
270 thread_id::<DimX>(),
271 thread_id::<DimY>(),
272 thread_id::<DimZ>(),
273 block_id::<DimX>(),
274 block_id::<DimY>(),
275 block_id::<DimZ>(),
276 ]
277 }
278
279 #[inline]
280 #[gpu_codegen::device]
281 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
282 thread_ids[(D::DIM_ID + DimTypeID::Max as u8) as usize] * block_dim::<D>()
283 + thread_ids[D::DIM_ID as usize]
284 }
285
286 #[inline]
287 #[gpu_codegen::device]
288 fn global_dim<D: DimType>() -> u32 {
289 dim::<D>()
290 }
291}
292
293#[derive(Copy, Clone)]
294pub struct Block2ThreadScope;
295impl PrivateTraitGuard for Block2ThreadScope {}
296impl ChunkScope for Block2ThreadScope {
297 type FromScope = Block;
298 type ToScope = Thread;
299
300 #[inline]
301 #[gpu_codegen::device]
302 fn thread_ids() -> [u32; TID_MAX_LEN] {
303 [thread_id::<DimX>(), thread_id::<DimY>(), thread_id::<DimZ>(), 0, 0, 0]
305 }
306
307 #[inline]
308 #[gpu_codegen::device]
309 fn global_dim<D: DimType>() -> u32 {
310 block_dim::<D>()
311 }
312
313 #[inline]
314 #[gpu_codegen::device]
315 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
316 thread_ids[D::DIM_ID as usize]
317 }
318}
319
320#[derive(Copy, Clone)]
321pub struct Grid2BlockScope;
322impl PrivateTraitGuard for Grid2BlockScope {}
323impl ChunkScope for Grid2BlockScope {
324 type FromScope = Grid;
325 type ToScope = Block;
326
327 #[inline]
328 #[gpu_codegen::device]
329 fn thread_ids() -> [u32; TID_MAX_LEN] {
330 [0, 0, 0, block_id::<DimX>(), block_id::<DimY>(), block_id::<DimZ>()]
331 }
332
333 #[inline]
334 #[gpu_codegen::device]
335 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
336 thread_ids[(D::DIM_ID + DimTypeID::Max as u8) as usize]
337 }
338
339 #[inline]
340 #[gpu_codegen::device]
341 fn global_dim<D: DimType>() -> u32 {
342 grid_dim::<D>()
343 }
344}
345
346#[derive(Copy, Clone)]
347pub struct Grid2WarpScope<const SIZE: usize>;
348
349impl<const SIZE: usize> PrivateTraitGuard for Grid2WarpScope<SIZE> {}
350
351impl<const SIZE: usize> Grid2WarpScope<SIZE> {
352 pub const CHECKED_SIZE: u32 = ThreadWarpTile::<SIZE>::CHECKED_SIZE;
353}
354
355impl<const SIZE: usize> ChunkScope for Grid2WarpScope<SIZE> {
356 type FromScope = Grid;
357 type ToScope = ThreadWarpTile<SIZE>;
358
359 #[inline]
360 #[gpu_codegen::device]
361 fn thread_ids() -> [u32; TID_MAX_LEN] {
362 [
366 0,
367 0,
368 Self::ToScope::_subgroup_id(),
369 block_id::<DimX>(),
370 block_id::<DimY>(),
371 block_id::<DimZ>(),
372 ]
373 }
374
375 #[inline]
376 #[gpu_codegen::device]
377 fn global_dim<D: DimType>() -> u32 {
378 if D::DIM_ID == 0 { (block_size() * num_blocks()) / Self::CHECKED_SIZE } else { 1 }
379 }
380
381 #[inline]
382 #[gpu_codegen::device]
383 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
384 if D::DIM_ID == 0 {
385 Grid2BlockScope::global_id::<D>(thread_ids) * Self::ToScope::_meta_group_size()
386 + thread_ids[2]
387 } else {
388 0
389 }
390 }
391}
392
393#[derive(Copy, Clone)]
394pub struct Block2WarpScope<const SIZE: usize>;
395
396impl<const SIZE: usize> PrivateTraitGuard for Block2WarpScope<SIZE> {}
397
398impl<const SIZE: usize> ChunkScope for Block2WarpScope<SIZE> {
399 type FromScope = Block;
400 type ToScope = ThreadWarpTile<SIZE>;
401
402 #[inline]
403 #[gpu_codegen::device]
404 fn thread_ids() -> [u32; TID_MAX_LEN] {
405 [0, 0, Self::ToScope::_subgroup_id(), 0, 0, 0]
407 }
408
409 #[inline]
410 #[gpu_codegen::device]
411 fn global_dim<D: DimType>() -> u32 {
412 if D::DIM_ID == 0 { Self::ToScope::_meta_group_size() } else { 1 }
413 }
414
415 #[inline]
416 #[gpu_codegen::device]
417 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
418 if D::DIM_ID == 0 { thread_ids[2] } else { 0 }
419 }
420}
421
422#[derive(Copy, Clone)]
423pub struct Warp2ThreadScope<const SIZE: usize>;
424
425impl<const SIZE: usize> PrivateTraitGuard for Warp2ThreadScope<SIZE> {}
426
427impl<const SIZE: usize> Warp2ThreadScope<SIZE> {
428 pub const CHECKED_SIZE: u32 = ThreadWarpTile::<SIZE>::CHECKED_SIZE;
429}
430
431impl<const SIZE: usize> ChunkScope for Warp2ThreadScope<SIZE> {
432 type FromScope = ThreadWarpTile<SIZE>;
433 type ToScope = Thread;
434
435 #[inline]
436 #[gpu_codegen::device]
437 fn thread_ids() -> [u32; TID_MAX_LEN] {
438 [0, Self::FromScope::_thread_rank(), 0, 0, 0, 0]
440 }
441
442 #[inline]
443 #[gpu_codegen::device]
444 fn global_dim<D: DimType>() -> u32 {
445 if D::DIM_ID == 0 { Self::CHECKED_SIZE } else { 1 }
446 }
447
448 #[inline]
449 #[gpu_codegen::device]
450 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
451 if D::DIM_ID == 0 { thread_ids[1] } else { 0 }
452 }
453}
454
455#[derive(Clone)]
456pub struct ChainedScope<CS1: ChunkScope, CS2: ChunkScope>
457where
458 CS2: ChunkScope<FromScope = CS1::ToScope>,
459{
460 _cs1: PhantomData<CS1>,
461 _cs2: PhantomData<CS2>,
462}
463
464impl<CS1: ChunkScope, CS2: ChunkScope> PrivateTraitGuard for ChainedScope<CS1, CS2> where
465 CS2: ChunkScope<FromScope = CS1::ToScope>
466{
467}
468impl<CS1: ChunkScope, CS2: ChunkScope> ChunkScope for ChainedScope<CS1, CS2>
469where
470 CS2: ChunkScope<FromScope = CS1::ToScope>,
471{
472 type FromScope = CS1::FromScope;
473
474 type ToScope = CS2::ToScope;
475
476 fn thread_ids() -> [u32; TID_MAX_LEN] {
477 let mut ids = CS1::thread_ids();
478 let ids2 = CS2::thread_ids();
479 for i in 0..TID_MAX_LEN {
480 ids[i] += ids2[i];
481 }
482 ids
483 }
484
485 fn global_dim<D: DimType>() -> u32 {
486 CS1::global_dim::<D>() * CS2::global_dim::<D>()
487 }
488
489 fn global_id<D: DimType>(thread_ids: [u32; TID_MAX_LEN]) -> u32 {
490 CS1::global_id::<D>(thread_ids) * CS2::global_dim::<D>() + CS2::global_id::<D>(thread_ids)
491 }
492}
493
494#[derive(Copy, Clone)]
495pub struct ChainedMap<
496 CS1: ChunkScope,
497 CS2: ChunkScope,
498 Map1: ScopeUniqueMap<CS1>,
499 Map2: ScopeUniqueMap<CS2>,
500> {
501 _cs1: PhantomData<CS1>,
502 _cs2: PhantomData<CS2>,
503 map1: Map1,
504 map2: Map2,
505}
506
507impl<CS1: ChunkScope, CS2: ChunkScope, Map1: ScopeUniqueMap<CS1>, Map2: ScopeUniqueMap<CS2>>
508 ChainedMap<CS1, CS2, Map1, Map2>
509where
510 CS2: ChunkScope<FromScope = CS1::ToScope>,
511{
512 pub fn new(m1: Map1, m2: Map2) -> Self {
513 Self { _cs1: PhantomData, _cs2: PhantomData, map1: m1, map2: m2 }
514 }
515}
516
517impl<CS1: ChunkScope, CS2: ChunkScope, Map1: ScopeUniqueMap<CS1>, Map2: ScopeUniqueMap<CS2>>
518 PrivateTraitGuard for ChainedMap<CS1, CS2, Map1, Map2>
519where
520 CS2: ChunkScope<FromScope = CS1::ToScope>,
521{
522}
523
524unsafe impl<CS1: ChunkScope, CS2: ChunkScope, Map1: ScopeUniqueMap<CS1>, Map2: ScopeUniqueMap<CS2>>
535 ScopeUniqueMap<ChainedScope<CS1, CS2>> for ChainedMap<CS1, CS2, Map1, Map2>
536where
537 CS2: ChunkScope<FromScope = CS1::ToScope>,
538 Map1: ScopeUniqueMap<CS1>,
539 Map2: ScopeUniqueMap<CS2>,
540 Map2::GlobalIndexType: AsPrimitive<Map1::IndexType>,
541{
542 type IndexType = Map2::IndexType;
543 type GlobalIndexType = Map1::GlobalIndexType;
544
545 fn map(
546 &self,
547 idx: Self::IndexType,
548 thread_ids: [u32; TID_MAX_LEN],
549 ) -> (bool, Self::GlobalIndexType) {
550 let (valid2, idx2) = self.map2.map(idx, thread_ids);
551 let (valid1, idx1) = self.map1.map(idx2.as_(), thread_ids);
552 (valid1 & valid2, idx1)
553 }
554}
555
#[cfg(any(test, doctest))]
pub mod test {
    use super::*;

    /// Host-side stand-in for `Block2WarpScope`: the warp id and block size are
    /// const parameters instead of device queries.
    #[derive(Clone)]
    pub struct MockBlock2WarpScope<const SIZE: usize, const WARP_ID: usize, const BLOCK_SIZE: usize>;

    impl<const SIZE: usize, const WARP_ID: usize, const BLOCK_SIZE: usize> PrivateTraitGuard
        for MockBlock2WarpScope<SIZE, WARP_ID, BLOCK_SIZE>
    {
    }

    impl<const SIZE: usize, const WARP_ID: usize, const BLOCK_SIZE: usize> ChunkScope
        for MockBlock2WarpScope<SIZE, WARP_ID, BLOCK_SIZE>
    {
        type FromScope = Block;
        type ToScope = ThreadWarpTile<SIZE>;

        fn thread_ids() -> [u32; TID_MAX_LEN] {
            let warp = WARP_ID as u32;
            [0, 0, warp, 0, 0, 0]
        }

        fn global_dim<D: DimType>() -> u32 {
            match D::DIM_ID {
                0 => (BLOCK_SIZE / SIZE) as u32,
                _ => 1,
            }
        }

        fn global_id<D: DimType>(ids: [u32; TID_MAX_LEN]) -> u32 {
            match D::DIM_ID {
                0 => ids[2],
                _ => 0,
            }
        }
    }

    /// Host-side stand-in for `Warp2ThreadScope`: the lane id is a const
    /// parameter instead of a device query.
    #[derive(Clone)]
    pub struct MockWarp2ThreadScope<const SIZE: usize, const LANE_ID: u32>;

    impl<const SIZE: usize, const LANE_ID: u32> PrivateTraitGuard
        for MockWarp2ThreadScope<SIZE, LANE_ID>
    {
    }

    impl<const SIZE: usize, const LANE_ID: u32> ChunkScope for MockWarp2ThreadScope<SIZE, LANE_ID> {
        type FromScope = ThreadWarpTile<SIZE>;
        type ToScope = Thread;

        fn thread_ids() -> [u32; TID_MAX_LEN] {
            [0, LANE_ID, 0, 0, 0, 0]
        }

        fn global_dim<D: DimType>() -> u32 {
            match D::DIM_ID {
                0 => SIZE as u32,
                _ => 1,
            }
        }

        fn global_id<D: DimType>(ids: [u32; TID_MAX_LEN]) -> u32 {
            match D::DIM_ID {
                0 => ids[1],
                _ => 0,
            }
        }
    }

    /// Runs `$m` as a `ScopeUniqueMap<$cs>` on `($idx, $thread_ids)` and checks
    /// the `(valid, index)` result against `$expected`. The mapped index is
    /// only compared when the mapping is reported valid.
    macro_rules! assert_map {
        ($cs:ty, $m:expr, $idx:expr, $thread_ids:expr, $expected:expr) => {
            let (ok, got) = ScopeUniqueMap::<$cs>::map(&$m, $idx, $thread_ids);
            assert!(
                ok == $expected.0 && (!ok || got == $expected.1),
                "idx = {}, mapped_idx = {}, valid = {} expected = {:?}",
                $idx,
                got,
                ok,
                $expected
            );
        };
    }

    /// Checks extents and ids reported by the two mock scopes.
    #[test]
    fn test_mocked_scope() {
        type WarpScope = MockBlock2WarpScope<32, 1, 128>;
        type LaneScope = MockWarp2ThreadScope<32, 1>;

        // Slot 1 is the lane id, slot 2 the warp id.
        let ids = [0, 1, 1, 0, 0, 0];

        assert!(WarpScope::global_dim_x() == 4, "dimx = {}", WarpScope::global_dim_x());
        assert!(WarpScope::global_id_x(ids) == 1, "id_x = {}", WarpScope::global_id_x(ids));
        assert!(WarpScope::global_dim_y() == 1, "dimy = {}", WarpScope::global_dim_y());
        assert!(WarpScope::global_id_y(ids) == 0, "id_y = {}", WarpScope::global_id_y(ids));
        assert!(WarpScope::global_dim_z() == 1, "dimz = {}", WarpScope::global_dim_z());
        assert!(WarpScope::global_id_z(ids) == 0, "id_z = {}", WarpScope::global_id_z(ids));
        assert!(LaneScope::global_dim_x() == 32, "dimx = {}", LaneScope::global_dim_x());
        assert!(LaneScope::global_id_x(ids) == 1, "id_x = {}", LaneScope::global_id_x(ids));
    }

    /// Chains a 64-wide warp map with a 2-wide lane map and compares each
    /// stage and the composition against hand-computed indices.
    #[test]
    fn test_chain_map() {
        type WarpScope = MockBlock2WarpScope<32, 1, 128>;
        type LaneScope = MockWarp2ThreadScope<32, 1>;

        let warp_map = crate::MapLinear::new(64);
        let lane_map = crate::MapLinear::new(2);
        let chained = ChainedMap::<WarpScope, LaneScope, _, _>::new(warp_map, lane_map);

        // Warp 0, lane 0.
        let origin = [0, 0, 0, 0, 0, 0];
        assert_map!(LaneScope, lane_map, 0, origin, (true, 0));
        assert_map!(WarpScope, warp_map, 0, origin, (true, 0));
        assert_map!(_, chained.clone(), 0, origin, (true, 0));

        assert_map!(LaneScope, lane_map, 1, origin, (true, 1));
        assert_map!(WarpScope, warp_map, 1, origin, (true, 1));
        assert_map!(_, chained.clone(), 1, origin, (true, 1));

        assert_map!(LaneScope, lane_map, 2, origin, (true, 64));

        // Warp 1, lane 0.
        let warp1_lane0 = [0, 0, 1, 0, 0, 0];
        assert_map!(LaneScope, lane_map, 0, warp1_lane0, (true, 0));
        assert_map!(WarpScope, warp_map, 0, warp1_lane0, (true, 64));
        assert_map!(_, chained.clone(), 0, warp1_lane0, (true, 64));

        assert_map!(LaneScope, lane_map, 1, warp1_lane0, (true, 1));
        assert_map!(WarpScope, warp_map, 1, warp1_lane0, (true, 65));
        assert_map!(_, chained.clone(), 1, warp1_lane0, (true, 65));

        // Warp 1, lane 1.
        let warp1_lane1 = [0, 1, 1, 0, 0, 0];
        assert_map!(LaneScope, lane_map, 0, warp1_lane1, (true, 2));
        assert_map!(WarpScope, warp_map, 2, warp1_lane1, (true, 66));
        assert_map!(_, chained.clone(), 0, warp1_lane1, (true, 66));
    }

    /// Same composition, but both maps use the full `WIDTH`, so the lane map's
    /// output for lane 1 already exceeds one warp-map row.
    #[test]
    fn test_chain_map_lane_0_only() {
        const BLOCK_SIZE: usize = 128;
        const WIDTH: usize = 64;
        const WARP_SIZE: usize = 32;
        const N: usize = BLOCK_SIZE / WARP_SIZE * WIDTH;

        type WarpScope = MockBlock2WarpScope<WARP_SIZE, 1, BLOCK_SIZE>;
        type LaneScope = MockWarp2ThreadScope<WARP_SIZE, 1>;

        let warp_map = crate::MapLinear::new(WIDTH);
        let lane_map = crate::MapLinear::new(WIDTH);
        let chained = ChainedMap::<WarpScope, LaneScope, _, _>::new(warp_map, lane_map);

        // Warp 0, lane 0.
        let origin = [0, 0, 0, 0, 0, 0];
        assert_map!(LaneScope, lane_map, 0, origin, (true, 0));
        assert_map!(WarpScope, warp_map, 0, origin, (true, 0));
        assert_map!(_, chained.clone(), 0, origin, (true, 0));

        assert_map!(LaneScope, lane_map, 1, origin, (true, 1));
        assert_map!(WarpScope, warp_map, 1, origin, (true, 1));
        assert_map!(_, chained.clone(), 1, origin, (true, 1));

        // Warp 1, lane 1.
        let warp1_lane1 = [0, 1, 1, 0, 0, 0];
        assert_map!(LaneScope, lane_map, 0, warp1_lane1, (true, WIDTH));
        assert_map!(WarpScope, warp_map, 64, warp1_lane1, (true, N + WIDTH));
        assert_map!(_, chained.clone(), 0, warp1_lane1, (true, N + WIDTH));
    }
}