gpu_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::ToTokens;
5mod arch;
6mod gpu_syntax;
7mod host_rewriter;
8mod reshape_map;
9
/// Code generation mode the macros in this crate expand for.
///
/// The value is chosen by `target()` from the `__CODEGEN_TARGET__`
/// environment variable.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub(crate) enum CodegenTarget {
    /// Selected by `"CPU"`; also the fallback when the variable cannot be read.
    Cpu,
    /// Selected by `"GPU"`.
    Gpu,
    /// Selected by `"GPU-CLIPPY"`.
    // NOTE(review): presumably a lint-only GPU pass run under Clippy — confirm.
    GpuClippy,
}

impl CodegenTarget {
    /// Returns `true` only for [`CodegenTarget::Gpu`]; `Cpu` and `GpuClippy`
    /// both yield `false`.
    fn is_gpu_only(&self) -> bool {
        match self {
            CodegenTarget::Gpu => true,
            CodegenTarget::Cpu | CodegenTarget::GpuClippy => false,
        }
    }

    /// Returns `true` when `gpu_codegen::…` tool attributes should be emitted,
    /// which is only the case for [`CodegenTarget::Gpu`].
    fn need_register_tool(&self) -> bool {
        *self == CodegenTarget::Gpu
    }
}
26
27fn target() -> CodegenTarget {
28    let target = std::env::var("__CODEGEN_TARGET__").unwrap_or_else(|_| "CPU".into());
29    match target.as_str() {
30        "CPU" => CodegenTarget::Cpu,
31        "GPU" => CodegenTarget::Gpu,
32        "GPU-CLIPPY" => CodegenTarget::GpuClippy,
33        _ => panic!("Unexpected __CODEGEN_TARGET__: {}", target),
34    }
35}
36
37#[proc_macro_attribute]
38pub fn kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
39    gpu_syntax::rewrite_gpu_code(attr, item, true, target())
40}
41
42/// This attribute generates a host wrapper around a kernel function, allowing it to be launched from the host.
43/// The kernel function itself is original function with Config.
44/// The generated host function is in mod #kname {pub fn launch(...)}
45#[proc_macro_attribute]
46pub fn cuda_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
47    host_rewriter::create_host_from_kernel(attr, item, target())
48}
49
50#[proc_macro_attribute]
51pub fn host(attr: TokenStream, item: TokenStream) -> TokenStream {
52    host_rewriter::rewrite(attr, item, target())
53}
54
55#[proc_macro_attribute]
56pub fn device(attr: TokenStream, item: TokenStream) -> TokenStream {
57    gpu_syntax::rewrite_gpu_code(attr, item, false, target())
58}
59
60/// Add gpu attributes to the kernel function
61/// e.g.
62/// #[gpu::attr(nvvm_launch_bound(256, 1, 1, 2))]
63#[proc_macro_attribute]
64pub fn attr(attr: TokenStream, item: TokenStream) -> TokenStream {
65    let mut kfun = syn::parse_macro_input!(item as syn::ItemFn);
66    let attr: proc_macro2::TokenStream = attr.into();
67    if target().need_register_tool() {
68        kfun.attrs.push(syn::parse_quote!(#[gpu_codegen::#attr]));
69    }
70    kfun.into_token_stream().into()
71}
72
73#[proc_macro]
74pub fn reshape_map_macro(input: TokenStream) -> TokenStream {
75    reshape_map::map_reshape_params(input)
76}
77
78#[proc_macro]
79pub fn nvptx_to_target_asm(input: TokenStream) -> TokenStream {
80    arch::replace_asm(input)
81}