gpu/
sync.rs

1//! Synchronization and atomic operations for GPU programming.
2
3use crate::{GpuGlobal, GpuShared};
4
/// Synchronizes the threads within a thread block.
///
/// Must not be reached under divergent control flow: every thread in the block
/// has to arrive at this sync point, otherwise the block deadlocks.
///
/// The Rust body is only a placeholder (`unimplemented!()`). The function is kept
/// out-of-line and tagged with the `gpu::sync_threads` diagnostic item and
/// `gpu_codegen::sync_data`, presumably so the GPU codegen can recognize and
/// lower calls to it. NOTE(review): confirm against the codegen.
#[inline(never)]
#[gpu_codegen::device]
#[rustc_diagnostic_item = "gpu::sync_threads"]
#[gpu_codegen::sync_data]
pub fn sync_threads() {
    unimplemented!();
}
15
/// Compile-time tag naming one MLIR atomic read-modify-write kind.
///
/// It mirrors `AtomicRMWKind` from
/// [MLIR ArithBase](https://github.com/llvm/llvm-project/blob/llvmorg-20.1.8/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td).
/// Implementors carry the kind as an MLIR attribute string in `NAME`; the kinds
/// defined in this module produce strings such as `"1: i64"`.
trait AtomicRMWKind {
    /// The kind encoded as an MLIR integer attribute.
    const NAME: &'static str;
}
21
/// Handle that restricts access to the wrapped data to atomic operations.
///
/// To prevent the user from mixing atomic operations with chunk-based access,
/// the reference to the data is wrapped in this struct. The `data` field is
/// private, so outside this module the data can only be reached through the
/// operations defined on `Atomic`.
///
/// To repurpose the data for non-atomic access, the user would have to drop the
/// `Atomic` first and then sync across all blocks. Sync across all blocks is not
/// supported yet, so such code should be rejected by the analysis.
/// TODO: Avoid repurposing data for atomic or chunk-based access.
#[rustc_diagnostic_item = "gpu::sync::Atomic"]
pub struct Atomic<'a, T: ?Sized> {
    data: &'a T,
}
34
35impl<'a, T> Atomic<'a, [T]> {
36    /// Get a reference to the data inside the Atomic struct.
37    /// This is unsafe because the user must ensure that no other thread
38    /// is accessing the data at the same time.
39    #[inline(always)]
40    #[gpu_codegen::device]
41    pub fn index(&'a self, i: usize) -> Atomic<'a, T> {
42        Atomic { data: &self.data[i] }
43    }
44}
45
/// Placeholder for the atomic read-modify-write intrinsic behind
/// `Atomic::atomic_rmw` and `SharedAtomic::atomic_rmw`.
///
/// The Rust body is `unimplemented!()`. The function is kept out-of-line and
/// tagged with the `gpu::atomic_rmw` diagnostic item, presumably so the GPU
/// codegen can replace calls with an MLIR `memref.atomic_rmw`.
/// NOTE(review): confirm against the codegen.
///
/// * `_mem` – the memory location operated on (presumably; confirm).
/// * `_val` – the operand combined with it according to `_kind`.
/// * `_kind` – the `AtomicRMWKind::NAME` of the operation, e.g. `"1: i64"`.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in this file. The callers in
/// this module only pass references held by `Atomic` / `SharedAtomic`, whose
/// data is not exposed for non-atomic access.
#[inline(never)]
#[gpu_codegen::device]
#[rustc_diagnostic_item = "gpu::atomic_rmw"]
unsafe fn _atomic_rmw<T: num_traits::Num>(_mem: &T, _val: T, _kind: &'static str) -> T {
    unimplemented!()
}
52
/// Atomic read-modify-write operation with kind defined in
/// [MLIR memref::atomic_rmw](https://mlir.llvm.org/docs/Dialects/MemRef/#memrefatomic_rmw-memrefatomicrmwop)
impl<'a, T: ?Sized> Atomic<'a, T> {
    /// Wraps the data held by a [`GpuGlobal`] so it can only be accessed atomically.
    ///
    /// The `GpuGlobal` is taken by value and its `data` reference is moved into
    /// the handle. Kept out-of-line and registered as the
    /// `gpu::sync::Atomic::new` diagnostic item, presumably so the GPU codegen
    /// can recognize it. NOTE(review): confirm.
    #[inline(never)]
    #[gpu_codegen::device]
    #[rustc_diagnostic_item = "gpu::sync::Atomic::new"]
    pub fn new(data: GpuGlobal<'a, T>) -> Atomic<'a, T> {
        Self { data: data.data }
    }

    /// Atomically applies read-modify-write kind `K` with operand `val` to the
    /// wrapped value, and returns the result of the underlying operation.
    ///
    /// `K` is one of the kind marker types defined in this module (e.g. [`AddI`],
    /// [`MaxS`]). Convenience methods such as `atomic_addi` call this with the
    /// matching kind.
    ///
    /// NOTE(review): whether the returned value is the one before or after the
    /// update depends on the codegen lowering, which is not visible here.
    #[inline(always)]
    #[gpu_codegen::device]
    // `AtomicRMWKind` is private, so code outside this module can only use the
    // kinds defined here.
    #[expect(private_bounds)]
    pub fn atomic_rmw<K: AtomicRMWKind>(&self, val: T) -> T
    where
        T: num_traits::Num,
    {
        // SAFETY (review): `self.data` is only reachable through this handle's
        // atomic operations; confirm this is the contract `_atomic_rmw` relies on.
        unsafe { _atomic_rmw(self.data, val, K::NAME) }
    }
}
73
/// Atomic handle for data wrapped in a [`GpuShared`]; the counterpart of [`Atomic`].
///
/// The `data` field is private, so outside this module the data can only be
/// reached through the operations defined on `SharedAtomic`.
/// [`SharedAtomic::new`] takes the `GpuShared` by `&'a mut`, which keeps it
/// exclusively borrowed for the handle's lifetime `'a`.
pub struct SharedAtomic<'a, T: ?Sized> {
    data: &'a GpuShared<T>,
}
77
/// Atomic read-modify-write operation with kind defined in
/// [MLIR memref::atomic_rmw](https://mlir.llvm.org/docs/Dialects/MemRef/#memrefatomic_rmw-memrefatomicrmwop)
impl<'a, T: ?Sized> SharedAtomic<'a, T> {
    /// Wraps a [`GpuShared`] so its data can only be accessed atomically.
    ///
    /// Taking `&'a mut` keeps the `GpuShared` exclusively borrowed for `'a`, so it
    /// cannot be accessed non-atomically while the handle is in use. The handle
    /// itself stores only a shared reference.
    ///
    /// Kept out-of-line and registered as the `gpu::sync::SharedAtomic::new`
    /// diagnostic item, presumably so the GPU codegen can recognize it.
    /// NOTE(review): confirm.
    #[inline(never)]
    #[gpu_codegen::device]
    #[rustc_diagnostic_item = "gpu::sync::SharedAtomic::new"]
    #[gpu_codegen::memspace_shared(1000)] // returned data is in shared memory
    pub fn new(data: &'a mut GpuShared<T>) -> SharedAtomic<'a, T> {
        Self { data }
    }

    /// Atomically applies read-modify-write kind `K` with operand `val` to the
    /// wrapped value, and returns the result of the underlying operation.
    ///
    /// `K` is one of the kind marker types defined in this module (e.g. [`AddI`],
    /// [`MaxS`]). Convenience methods such as `atomic_addi` call this with the
    /// matching kind.
    ///
    /// NOTE(review): whether the returned value is the one before or after the
    /// update depends on the codegen lowering, which is not visible here.
    #[inline(always)]
    #[gpu_codegen::device]
    // `AtomicRMWKind` is private, so code outside this module can only use the
    // kinds defined here.
    #[expect(private_bounds)]
    pub fn atomic_rmw<K: AtomicRMWKind>(&self, val: T) -> T
    where
        T: num_traits::Num,
    {
        // SAFETY (review): `self.data` is only reachable through this handle's
        // atomic operations; confirm this is the contract `_atomic_rmw` relies on.
        // NOTE(review): `self.data` is a `&GpuShared<T>` but `_atomic_rmw` takes
        // `&T`; presumably `GpuShared<T>` derefs to `T` — confirm.
        unsafe { _atomic_rmw(self.data, val, K::NAME) }
    }
}
99
100impl<'a, T> SharedAtomic<'a, [T]> {
101    /// Get a reference to the data inside the Atomic struct.
102    /// This is unsafe because the user must ensure that no other thread
103    /// is accessing the data at the same time.
104    #[inline(always)]
105    #[gpu_codegen::device]
106    #[gpu_codegen::memspace_shared(1000)] // returned data is in shared memory
107    pub fn index(&'a self, i: usize) -> SharedAtomic<'a, T> {
108        SharedAtomic { data: &self.data[i] }
109    }
110}
111
112macro_rules! def_atomic_rmw_kind {
113    ($t:ident, $val:literal, $atomic_fn: ident, $trait:path) => {
114        #[doc = concat!("Atomic operation kind for [`Atomic<'a, T>::", stringify!($atomic_fn), "`]")]
115        pub struct $t;
116        impl AtomicRMWKind for $t {
117            const NAME: &str = concat!(stringify!($val), ": i64");
118        }
119
120        impl<'a, T: num_traits::Num> Atomic<'a, T> {
121            #[doc = concat!("Equivalent to: atomic_rmw::<[`Atomic<'a, T>::", stringify!($t), "`]>")]
122            #[inline(always)]
123            #[gpu_codegen::device]
124            pub fn $atomic_fn(&self, val: T) -> T
125            where
126                T: $trait,
127            {
128                self.atomic_rmw::<$t>(val)
129            }
130        }
131
132        impl<'a, T: num_traits::Num> SharedAtomic<'a, T> {
133            #[doc = concat!("Equivalent to: atomic_rmw::<[`SharedAtomic<'a, T>::", stringify!($t), "`]>")]
134            #[inline(always)]
135            #[gpu_codegen::device]
136            pub fn $atomic_fn(&self, val: T) -> T
137            where
138                T: $trait,
139            {
140                self.atomic_rmw::<$t>(val)
141            }
142        }
143    };
144}
145
146macro_rules! def_atomic_rmw_kinds {
147    ($($t:ident, $val:literal, $atomic_fn: ident, $trait:path);* $(;)?) => {
148        $(
149            def_atomic_rmw_kind!($t, $val, $atomic_fn, $trait);
150        )*
151    };
152}
153
154// Define AtomicRMWKind
155// mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
156def_atomic_rmw_kinds!(
157    AddF, 0, atomic_addf, num_traits::Float;
158    AddI, 1, atomic_addi, num_traits::PrimInt;
159    Assign, 2, atomic_assign, num_traits::Num;
160    MaximumF, 3, atomic_maximumf, num_traits::Float;
161    MaxS, 4, atomic_maxs, num_traits::Signed;
162    MaxU, 5, atomic_maxu, num_traits::Unsigned;
163    MinimumF, 6, atomic_minimumf, num_traits::Float;
164    MinS, 7, atomic_mins, num_traits::Signed;
165    MinU, 8, atomic_minu, num_traits::Unsigned;
166    MulF, 9, atomic_mulf, num_traits::Float;
167    MulI, 10, atomic_muli, num_traits::PrimInt;
168    OrI, 11, atomic_ori, num_traits::PrimInt;
169    AndI, 12, atomic_andi, num_traits::PrimInt;
170    MaxNumF, 13, atomic_maxnumf, num_traits::Float;
171    MinNumF, 14, atomic_minnumf, num_traits::Float;
172);