You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
7.5 KiB
210 lines
7.5 KiB
//! Core dependency injection macros
|
|
|
|
extern crate proc_macro;
|
|
use proc_macro::TokenStream;
|
|
use quote::{format_ident, quote};
|
|
use syn::parse::{Parse, ParseStream, Result};
|
|
use syn::punctuated::Punctuated;
|
|
use syn::{
|
|
braced, parse, parse_macro_input, DeriveInput, Fields, FnArg, Ident, ItemFn, ItemStruct, Path,
|
|
Token, Type,
|
|
};
|
|
|
|
/// Defines a provider function, with generated helper that implicitly fetches argument instances from the registry
|
|
#[proc_macro_attribute]
|
|
pub fn provides(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
|
let function: ItemFn = parse(item).expect("can only be applied to functions");
|
|
|
|
// Create the info needed to refer to the function & the injected version we generate
|
|
let ident = function.sig.ident.clone();
|
|
let injected_ident = format_ident!("__gddi_{}_injected", ident);
|
|
|
|
// Create the info needed to generate the call to the original function
|
|
let inputs = function.sig.inputs.iter().map(|arg| {
|
|
if let FnArg::Typed(t) = arg {
|
|
return t.ty.clone();
|
|
}
|
|
panic!("can't be applied to struct methods");
|
|
});
|
|
let local_var_idents = (0..inputs.len()).map(|i| format_ident!("__input{}", i));
|
|
let local_var_idents_for_call = local_var_idents.clone();
|
|
|
|
let emitted_code = quote! {
|
|
// Injecting wrapper
|
|
fn #injected_ident(registry: std::sync::Arc<gddi::Registry>) -> std::pin::Pin<gddi::ProviderFutureBox> {
|
|
Box::pin(async move {
|
|
// Create a local variable for each argument, to ensure they get generated in a
|
|
// deterministic order (compiler complains otherwise)
|
|
#(let #local_var_idents = registry.get::<#inputs>().await;)*
|
|
|
|
// Actually call the original function
|
|
Box::new(#ident(#(#local_var_idents_for_call),*).await) as Box<dyn std::any::Any>
|
|
})
|
|
}
|
|
#function
|
|
};
|
|
emitted_code.into()
|
|
}
|
|
|
|
struct ModuleDef {
|
|
name: Ident,
|
|
providers: Punctuated<ProviderDef, Token![,]>,
|
|
submodules: Punctuated<Path, Token![,]>,
|
|
}
|
|
|
|
enum ModuleEntry {
|
|
Providers(Punctuated<ProviderDef, Token![,]>),
|
|
Submodules(Punctuated<Path, Token![,]>),
|
|
}
|
|
|
|
struct ProviderDef {
|
|
ty: Type,
|
|
ident: Ident,
|
|
parts: bool,
|
|
}
|
|
|
|
impl Parse for ModuleDef {
|
|
fn parse(input: ParseStream) -> Result<Self> {
|
|
// first thing is the module name followed by a comma
|
|
let name = input.parse()?;
|
|
input.parse::<Token![,]>()?;
|
|
// Then comes submodules or provider sections, in any order
|
|
let entries: Punctuated<ModuleEntry, Token![,]> = Punctuated::parse_terminated(input)?;
|
|
let mut providers = Punctuated::new();
|
|
let mut submodules = Punctuated::new();
|
|
for entry in entries.into_iter() {
|
|
match entry {
|
|
ModuleEntry::Providers(value) => {
|
|
if !providers.is_empty() {
|
|
panic!("providers specified more than once");
|
|
}
|
|
providers = value;
|
|
}
|
|
ModuleEntry::Submodules(value) => {
|
|
if !submodules.is_empty() {
|
|
panic!("submodules specified more than once");
|
|
}
|
|
submodules = value;
|
|
}
|
|
}
|
|
}
|
|
Ok(ModuleDef { name, providers, submodules })
|
|
}
|
|
}
|
|
|
|
impl Parse for ProviderDef {
|
|
fn parse(input: ParseStream) -> Result<Self> {
|
|
let parts = input.peek3(Token![=>]);
|
|
if parts {
|
|
match input.parse::<Ident>()?.to_string().as_str() {
|
|
"parts" => {}
|
|
keyword => panic!("expected 'parts', got '{}'", keyword),
|
|
}
|
|
}
|
|
|
|
// A provider definition follows this format: <Type> -> <function name>
|
|
let ty = input.parse()?;
|
|
input.parse::<Token![=>]>()?;
|
|
let ident = input.parse()?;
|
|
Ok(ProviderDef { ty, ident, parts })
|
|
}
|
|
}
|
|
|
|
impl Parse for ModuleEntry {
|
|
fn parse(input: ParseStream) -> Result<Self> {
|
|
match input.parse::<Ident>()?.to_string().as_str() {
|
|
"providers" => {
|
|
let entries;
|
|
braced!(entries in input);
|
|
Ok(ModuleEntry::Providers(entries.parse_terminated(ProviderDef::parse)?))
|
|
}
|
|
"submodules" => {
|
|
let entries;
|
|
braced!(entries in input);
|
|
Ok(ModuleEntry::Submodules(entries.parse_terminated(Path::parse)?))
|
|
}
|
|
keyword => {
|
|
panic!("unexpected keyword: {}", keyword);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Emits a module function that registers submodules & providers with the registry
|
|
#[proc_macro]
|
|
pub fn module(item: TokenStream) -> TokenStream {
|
|
let module = parse_macro_input!(item as ModuleDef);
|
|
let init_ident = module.name.clone();
|
|
let providers = module.providers.iter();
|
|
let types = providers.clone().map(|p| p.ty.clone());
|
|
let provider_idents =
|
|
providers.clone().map(|p| format_ident!("__gddi_{}_injected", p.ident.clone()));
|
|
let parting_functions = providers.filter_map(|p| match &p.ty {
|
|
Type::Path(ty) if p.parts => Some(format_ident!(
|
|
"__gddi_part_out_{}",
|
|
ty.path.get_ident().unwrap().to_string().to_lowercase()
|
|
)),
|
|
_ => None,
|
|
});
|
|
let submodule_idents = module.submodules.iter();
|
|
let emitted_code = quote! {
|
|
#[doc(hidden)]
|
|
#[allow(missing_docs)]
|
|
pub fn #init_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
|
|
// Register all providers on this module
|
|
let ret = builder#(.register_provider::<#types>(Box::new(#provider_idents)))*
|
|
// Register all submodules on this module
|
|
#(.register_module(#submodule_idents))*;
|
|
|
|
#(let ret = #parting_functions(ret);)*
|
|
|
|
ret
|
|
}
|
|
};
|
|
emitted_code.into()
|
|
}
|
|
|
|
/// Emits a default implementation for Stoppable that does nothing;
|
|
#[proc_macro_derive(Stoppable)]
|
|
pub fn derive_nop_stop(item: TokenStream) -> TokenStream {
|
|
let input = parse_macro_input!(item as DeriveInput);
|
|
let ident = input.ident;
|
|
let emitted_code = quote! {
|
|
impl gddi::Stoppable for #ident {}
|
|
};
|
|
emitted_code.into()
|
|
}
|
|
|
|
/// Generates the code necessary to split up a type into its components
|
|
#[proc_macro_attribute]
|
|
pub fn part_out(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
|
let struct_: ItemStruct = parse(item).expect("can only be applied to struct definitions");
|
|
let struct_ident = struct_.ident.clone();
|
|
let fields = match struct_.fields.clone() {
|
|
Fields::Named(f) => f,
|
|
_ => panic!("can only be applied to structs with named fields"),
|
|
}
|
|
.named;
|
|
|
|
let field_names = fields.iter().map(|f| f.ident.clone().expect("field without a name"));
|
|
let field_types = fields.iter().map(|f| f.ty.clone());
|
|
|
|
let fn_ident = format_ident!("__gddi_part_out_{}", struct_ident.to_string().to_lowercase());
|
|
|
|
let emitted_code = quote! {
|
|
#struct_
|
|
|
|
fn #fn_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
|
|
builder#(.register_provider::<#field_types>(Box::new(
|
|
|registry: std::sync::Arc<gddi::Registry>| -> std::pin::Pin<gddi::ProviderFutureBox> {
|
|
Box::pin(async move {
|
|
Box::new(async move {
|
|
registry.get::<#struct_ident>().await.#field_names
|
|
}.await) as Box<dyn std::any::Any>
|
|
})
|
|
})))*
|
|
}
|
|
};
|
|
emitted_code.into()
|
|
}
|