1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::ToTokens;
5mod arch;
6mod gpu_syntax;
7mod host_rewriter;
8mod reshape_map;
9
/// Code-generation backend that the macros in this crate expand for.
///
/// The active value is chosen by `target()` from the `__CODEGEN_TARGET__`
/// environment variable.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub(crate) enum CodegenTarget {
    /// Selected by `__CODEGEN_TARGET__=CPU`; also used when the variable is unset.
    Cpu,
    /// Selected by `__CODEGEN_TARGET__=GPU`.
    Gpu,
    /// Selected by `__CODEGEN_TARGET__=GPU-CLIPPY` (presumably a clippy pass over
    /// GPU code — confirm with the build scripts).
    GpuClippy,
}

impl CodegenTarget {
    /// Returns `true` only for [`CodegenTarget::Gpu`]; `Cpu` and `GpuClippy` yield `false`.
    fn is_gpu_only(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match *self {
            CodegenTarget::Gpu => true,
            CodegenTarget::Cpu | CodegenTarget::GpuClippy => false,
        }
    }

    /// Whether `gpu_codegen::` tool attributes should be emitted (the `attr`
    /// macro checks this). Only [`CodegenTarget::Gpu`] answers `true`.
    fn need_register_tool(&self) -> bool {
        match *self {
            CodegenTarget::Gpu => true,
            CodegenTarget::Cpu | CodegenTarget::GpuClippy => false,
        }
    }
}
26
27fn target() -> CodegenTarget {
28 let target = std::env::var("__CODEGEN_TARGET__").unwrap_or_else(|_| "CPU".into());
29 match target.as_str() {
30 "CPU" => CodegenTarget::Cpu,
31 "GPU" => CodegenTarget::Gpu,
32 "GPU-CLIPPY" => CodegenTarget::GpuClippy,
33 _ => panic!("Unexpected __CODEGEN_TARGET__: {}", target),
34 }
35}
36
37#[proc_macro_attribute]
38pub fn kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
39 gpu_syntax::rewrite_gpu_code(attr, item, true, target())
40}
41
42#[proc_macro_attribute]
46pub fn cuda_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
47 host_rewriter::create_host_from_kernel(attr, item, target())
48}
49
50#[proc_macro_attribute]
51pub fn host(attr: TokenStream, item: TokenStream) -> TokenStream {
52 host_rewriter::rewrite(attr, item, target())
53}
54
55#[proc_macro_attribute]
56pub fn device(attr: TokenStream, item: TokenStream) -> TokenStream {
57 gpu_syntax::rewrite_gpu_code(attr, item, false, target())
58}
59
60#[proc_macro_attribute]
64pub fn attr(attr: TokenStream, item: TokenStream) -> TokenStream {
65 let mut kfun = syn::parse_macro_input!(item as syn::ItemFn);
66 let attr: proc_macro2::TokenStream = attr.into();
67 if target().need_register_tool() {
68 kfun.attrs.push(syn::parse_quote!(#[gpu_codegen::#attr]));
69 }
70 kfun.into_token_stream().into()
71}
72
73#[proc_macro]
74pub fn reshape_map_macro(input: TokenStream) -> TokenStream {
75 reshape_map::map_reshape_params(input)
76}
77
78#[proc_macro]
79pub fn nvptx_to_target_asm(input: TokenStream) -> TokenStream {
80 arch::replace_asm(input)
81}