/// Type-level marker for the X dimension; pass it as the `D` parameter of the
/// dimension query functions (e.g. `thread_id::<DimX>()`).
pub struct DimX;
/// Type-level marker for the Y dimension.
pub struct DimY;
/// Type-level marker for the Z dimension.
pub struct DimZ;
4
/// Compile-time description of a single dimension, implemented by the marker
/// types `DimX`, `DimY` and `DimZ`.
pub trait DimType {
    /// Numeric identifier of the dimension.
    const DIM_ID: u8;
    /// Attribute text naming the dimension; the implementations in this file
    /// use the form `#gpu<dim x>`.
    ///
    /// The `'static` lifetime is written out so the declaration reads exactly
    /// like the implementations instead of relying on elision.
    const MLIR_DIM: &'static str;
}
9
/// Numeric identifiers of the dimensions; the `DimType` impls cast these to
/// `u8` for `DIM_ID`.
///
/// `#[repr(u8)]` pins the discriminant type so those `as u8` casts are
/// lossless by construction: a discriminant that does not fit in a `u8` is
/// rejected at compile time instead of being silently truncated.
#[repr(u8)]
pub(crate) enum DimTypeID {
    X = 0,
    Y = 1,
    Z = 2,
    // NOTE(review): not referenced by any impl visible here; presumably the
    // number of real dimensions / an upper bound — confirm against callers.
    Max = 3,
}
16
17impl DimType for DimX {
18 const DIM_ID: u8 = DimTypeID::X as u8;
19 const MLIR_DIM: &'static str = "#gpu<dim x>";
20}
21
22impl DimType for DimY {
23 const DIM_ID: u8 = DimTypeID::Y as u8;
24 const MLIR_DIM: &'static str = "#gpu<dim y>";
25}
26
27impl DimType for DimZ {
28 const DIM_ID: u8 = DimTypeID::Z as u8;
29 const MLIR_DIM: &'static str = "#gpu<dim z>";
30}
31
// Generates the pair of device functions behind one dimension query.
//
// Arguments, in order:
//   * `$name`      - only used to build the diagnostic-item string
//                    `gpu::<name>` attached to the private function;
//   * `$priv_name` - name of the private, non-generic placeholder function;
//   * `$pub_name`  - name of the public, `DimType`-generic wrapper;
//   * zero or more attributes, applied to BOTH generated functions.
//
// The comma after `$pub_name` is mandatory even when no attributes follow.
macro_rules! def_dim_fn {
    ($name:ident, $priv_name:ident, $pub_name:ident, $(#[$meta:meta])*) => {
        // Placeholder whose Rust body is never meant to run (it is
        // `unimplemented!()`). NOTE(review): presumably the codegen backend
        // finds it through the diagnostic item and substitutes the real
        // operation; `inline(never)` would keep the call visible for that.
        $(#[$meta])*
        #[rustc_diagnostic_item = concat!("gpu::", stringify!($name))]
        #[gpu_macros::device]
        #[inline(never)]
        const fn $priv_name() -> usize {
            unimplemented!()
        }

        // Public wrapper: records the dimension's attribute text, then calls
        // the placeholder and narrows its `usize` result to `u32` (the
        // truncation lint is explicitly expected here).
        $(#[$meta])*
        #[gpu_macros::device]
        #[inline(always)]
        #[expect(clippy::cast_possible_truncation)]
        pub const fn $pub_name<D: DimType>() -> u32 {
            crate::add_mlir_string_attr(D::MLIR_DIM);
            $priv_name() as u32
        }
    };
}
52
53def_dim_fn!(global_thread_id, _global_thread_id, global_id,);
54def_dim_fn!(thread_id, _thread_id, thread_id,);
55def_dim_fn!(block_id, _block_id, block_id,);
56def_dim_fn!(block_dim, _block_dim, block_dim, #[gpu_codegen::ret_sync_data(1000)]);
57def_dim_fn!(grid_dim, _grid_dim, grid_dim, #[gpu_codegen::ret_sync_data(1000)]);
58
59#[gpu_macros::device]
64#[inline(always)]
65#[allow(dead_code)]
66pub fn sm_warp_id() -> u32 {
67 let mut ret: u32;
68 unsafe {
69 crate::asm!("mov.u32 {0:reg32}, %warpid;", out(reg) ret);
70 }
71 ret
72}
73
74#[gpu_macros::device]
75#[inline(always)]
76#[allow(dead_code)]
77pub fn lane_id() -> u32 {
78 let mut laneid: u32;
79 unsafe {
80 crate::asm!("mov.u32 {0:reg32}, %laneid;", out(reg) laneid);
81 }
82 laneid
83}
84
85#[gpu_macros::device]
86#[gpu_codegen::ret_sync_data(1000)]
87#[inline(always)]
88pub const fn dim<D: DimType>() -> u32 {
89 block_dim::<D>() * grid_dim::<D>()
90}
91
92#[gpu_macros::device]
93#[gpu_codegen::ret_sync_data(1000)]
94#[inline(always)]
95pub fn block_size() -> u32 {
96 block_dim::<DimX>() * block_dim::<DimY>() * block_dim::<DimZ>()
97}
98
99#[gpu_macros::device]
100#[gpu_codegen::ret_sync_data(1000)]
101#[inline(always)]
102#[allow(dead_code)]
103pub fn num_blocks() -> u32 {
104 grid_dim::<DimX>() * grid_dim::<DimY>() * grid_dim::<DimZ>()
105}