shed/fbinit/macros/expand.rs (148 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under both the MIT license found in the
* LICENSE-MIT file in the root directory of this source tree and the Apache
* License, Version 2.0 found in the LICENSE-APACHE file in the root directory
* of this source tree.
*/
use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{parse_quote, Error, ItemFn, Result, Token};
/// Which flavour of fbinit expansion `expand` performs.
///
/// `Main` requires the annotated function to be named `main` and emits a crate-root check;
/// `Test` appends `#[test]` to the generated function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// Expansion for the program entry point.
    Main,
    /// Expansion for a test function.
    Test,
}
mod kw {
syn::custom_keyword!(disable_fatal_signals);
syn::custom_keyword!(none);
syn::custom_keyword!(sigterm_only);
syn::custom_keyword!(all);
}
/// The value given to `disable_fatal_signals = ...`, keeping the parsed token (and its span).
pub enum DisableFatalSignals {
    /// `none`: initialise through `fbinit::perform_init()`.
    None(kw::none),
    /// `sigterm_only`: disable fatal-signal handling for SIGTERM only.
    SigtermOnly(kw::sigterm_only),
    /// `default`: the behaviour used when the argument is omitted.
    Default(Token![default]),
    /// `all`: disable fatal-signal handling for every signal covered by the mask.
    All(kw::all),
}
/// One comma-separated argument passed to `expand`.
pub enum Arg {
    /// `disable_fatal_signals = <value>`
    DisableFatalSignals {
        /// The `disable_fatal_signals` keyword.
        kw_token: kw::disable_fatal_signals,
        /// The `=` separating key and value.
        eq_token: Token![=],
        /// The selected signal-handling mode.
        value: DisableFatalSignals,
    },
}
impl Parse for Arg {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(kw::disable_fatal_signals) {
let kw_token = input.parse()?;
let eq_token = input.parse()?;
let lookahead = input.lookahead1();
let value = if lookahead.peek(kw::none) {
DisableFatalSignals::None(input.parse()?)
} else if lookahead.peek(Token![default]) {
DisableFatalSignals::Default(input.parse()?)
} else if lookahead.peek(kw::all) {
DisableFatalSignals::All(input.parse()?)
} else if lookahead.peek(kw::sigterm_only) {
DisableFatalSignals::SigtermOnly(input.parse()?)
} else {
return Err(lookahead.error());
};
Ok(Self::DisableFatalSignals {
kw_token,
eq_token,
value,
})
} else {
Err(lookahead.error())
}
}
}
pub fn expand(
mode: Mode,
args: Punctuated<Arg, Token![,]>,
mut function: ItemFn,
) -> Result<TokenStream> {
let mut disable_fatal_signals =
DisableFatalSignals::Default(syn::parse2(quote! { default }).expect("This always parses"));
for arg in args {
match arg {
Arg::DisableFatalSignals { value, .. } => disable_fatal_signals = value,
}
}
if function.sig.inputs.len() > 1 {
return Err(Error::new_spanned(
function.sig,
"expected one argument of type fbinit::FacebookInit",
));
}
if mode == Mode::Main && function.sig.ident != "main" {
return Err(Error::new_spanned(
function.sig,
"#[fbinit::main] must be used on the main function",
));
}
let guard = match mode {
Mode::Main => Some(quote! {
if module_path!().contains("::") {
panic!("fbinit must be performed in the crate root on the main function");
}
}),
Mode::Test => None,
};
let assignment = function.sig.inputs.first().map(|arg| quote!(let #arg =));
function.sig.inputs = Punctuated::new();
let block = function.block;
let body = match (function.sig.asyncness.is_some(), mode) {
(true, Mode::Test) => quote! {
fbinit_tokio::tokio_test(async #block )
},
(true, Mode::Main) => quote! {
fbinit_tokio::tokio_main(async #block )
},
(false, _) => {
let stmts = block.stmts;
quote! { #(#stmts)* }
}
};
let perform_init = match disable_fatal_signals {
DisableFatalSignals::Default(_) => {
// 8002 is 1 << 15 (SIGTERM) | 1 << 2 (SIGINT)
quote! {
fbinit::r#impl::perform_init_with_disable_signals(0x8002)
}
}
DisableFatalSignals::All(_) => {
// ffff is a mask of all 1's
quote! {
fbinit::r#impl::perform_init_with_disable_signals(0xffff)
}
}
DisableFatalSignals::SigtermOnly(_) => {
// 8000 is 1 << 15 (SIGTERM)
quote! {
fbinit::r#impl::perform_init_with_disable_signals(0x8000)
}
}
DisableFatalSignals::None(_) => {
quote! {
fbinit::perform_init()
}
}
};
function.block = parse_quote!({
#guard
#assignment unsafe {
#perform_init
};
let destroy_guard = unsafe { fbinit::r#impl::DestroyGuard::new() };
#body
});
function.sig.asyncness = None;
if mode == Mode::Test {
function.attrs.push(parse_quote!(#[test]));
}
Ok(quote!(#function))
}