gpu/
dim.rs

/// Zero-sized marker type selecting the `x` dimension.
///
/// Used as the type argument of the dimension queries in this module; its
/// `DimType` implementation carries the MLIR attribute `"#gpu<dim x>"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DimX;
/// Zero-sized marker type selecting the `y` dimension.
///
/// Used as the type argument of the dimension queries in this module; its
/// `DimType` implementation carries the MLIR attribute `"#gpu<dim y>"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DimY;
/// Zero-sized marker type selecting the `z` dimension.
///
/// Used as the type argument of the dimension queries in this module; its
/// `DimType` implementation carries the MLIR attribute `"#gpu<dim z>"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DimZ;
4
/// Compile-time description of a dimension selector.
///
/// Implemented in this file by the marker types `DimX`, `DimY` and `DimZ`,
/// and used as the bound on the generic dimension queries.
pub trait DimType {
    /// Numeric identifier of the dimension. The implementors in this file
    /// use the `DimTypeID` discriminants cast to `u8`.
    const DIM_ID: u8;

    /// MLIR attribute text naming the dimension, e.g. `"#gpu<dim x>"`.
    ///
    /// The `'static` lifetime is written out so the declaration matches the
    /// implementors and does not depend on lifetime elision in associated
    /// constants.
    const MLIR_DIM: &'static str;
}
9
/// Numeric identifiers for the dimensions, used as the source of
/// `DimType::DIM_ID`.
///
/// The representation is pinned to `u8` because the discriminants are
/// converted with `as u8` at their use sites.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DimTypeID {
    /// Identifier of the `x` dimension.
    X = 0,
    /// Identifier of the `y` dimension.
    Y = 1,
    /// Identifier of the `z` dimension.
    Z = 2,
    /// One past the last dimension identifier.
    // NOTE(review): presumably a count/sentinel rather than a real
    // dimension — confirm against the users of this enum.
    Max = 3,
}
16
17impl DimType for DimX {
18    const DIM_ID: u8 = DimTypeID::X as u8;
19    const MLIR_DIM: &'static str = "#gpu<dim x>";
20}
21
22impl DimType for DimY {
23    const DIM_ID: u8 = DimTypeID::Y as u8;
24    const MLIR_DIM: &'static str = "#gpu<dim y>";
25}
26
27impl DimType for DimZ {
28    const DIM_ID: u8 = DimTypeID::Z as u8;
29    const MLIR_DIM: &'static str = "#gpu<dim z>";
30}
31
// Generates a pair of device functions for one dimension query.
//
// Arguments (positional):
//   $diag   - name used to build the diagnostic item string "gpu::<$diag>".
//   $raw    - name of the private stub: a never-inlined `const fn() -> usize`
//             whose Rust body is `unimplemented!()`.
//             NOTE(review): presumably the GPU codegen locates the stub via
//             its diagnostic item and supplies the real value — confirm.
//   $public - name of the public, always-inlined wrapper
//             `const fn<D: DimType>() -> u32`. It first records
//             `D::MLIR_DIM` through `crate::add_mlir_string_attr` and then
//             returns the stub's result narrowed to `u32`.
//   trailing `#[...]` attributes (zero or more) are attached to BOTH
//   generated functions.
macro_rules! def_dim_fn {
    (
        $diag:ident,
        $raw:ident,
        $public:ident,
        $(#[$attr:meta])*
    ) => {
        $(#[$attr])*
        #[rustc_diagnostic_item = concat!("gpu::", stringify!($diag))]
        #[gpu_macros::device]
        #[inline(never)]
        const fn $raw() -> usize {
            unimplemented!()
        }

        $(#[$attr])*
        #[gpu_macros::device]
        #[inline(always)]
        #[expect(clippy::cast_possible_truncation)]
        pub const fn $public<D: DimType>() -> u32 {
            // The attribute must be recorded before the stub is called.
            crate::add_mlir_string_attr(D::MLIR_DIM);
            $raw() as u32
        }
    };
}
52
53def_dim_fn!(global_thread_id, _global_thread_id, global_id,);
54def_dim_fn!(thread_id, _thread_id, thread_id,);
55def_dim_fn!(block_id, _block_id, block_id,);
56def_dim_fn!(block_dim, _block_dim, block_dim, #[gpu_codegen::ret_sync_data(1000)]);
57def_dim_fn!(grid_dim, _grid_dim, grid_dim, #[gpu_codegen::ret_sync_data(1000)]);
58
59/// This is the hardware warp id indicating
60/// the warp within a SM.
61/// Thus, it is different from the subgroup id /warp id
62/// we usually use inside BlockTile.
63#[gpu_macros::device]
64#[inline(always)]
65#[allow(dead_code)]
66pub fn sm_warp_id() -> u32 {
67    let mut ret: u32;
68    unsafe {
69        crate::asm!("mov.u32 {0:reg32}, %warpid;", out(reg) ret);
70    }
71    ret
72}
73
74#[gpu_macros::device]
75#[inline(always)]
76#[allow(dead_code)]
77pub fn lane_id() -> u32 {
78    let mut laneid: u32;
79    unsafe {
80        crate::asm!("mov.u32 {0:reg32}, %laneid;", out(reg) laneid);
81    }
82    laneid
83}
84
85#[gpu_macros::device]
86#[gpu_codegen::ret_sync_data(1000)]
87#[inline(always)]
88pub const fn dim<D: DimType>() -> u32 {
89    block_dim::<D>() * grid_dim::<D>()
90}
91
92#[gpu_macros::device]
93#[gpu_codegen::ret_sync_data(1000)]
94#[inline(always)]
95pub fn block_size() -> u32 {
96    block_dim::<DimX>() * block_dim::<DimY>() * block_dim::<DimZ>()
97}
98
99#[gpu_macros::device]
100#[gpu_codegen::ret_sync_data(1000)]
101#[inline(always)]
102#[allow(dead_code)]
103pub fn num_blocks() -> u32 {
104    grid_dim::<DimX>() * grid_dim::<DimY>() * grid_dim::<DimZ>()
105}