parse_display_derive/syn_utils.rs

1use proc_macro2::TokenStream;
2use quote::{quote, quote_spanned, ToTokens};
3use std::collections::HashSet;
4use syn::{
5    ext::IdentExt,
6    parse::discouraged::Speculative,
7    parse::Parse,
8    parse::ParseStream,
9    parse2, parse_str,
10    punctuated::Punctuated,
11    visit::{visit_path, Visit},
12    DeriveInput, GenericParam, Generics, Ident, LitStr, Path, Result, Token, Type, WherePredicate,
13};
14
/// Returns early from the enclosing function with `Err(syn::Error)`.
///
/// Forms:
/// - `bail!(_, "fmt", args...)` — the error is located at `Span::call_site()`.
/// - `bail!(span, "fmt")` — error at `span`; a trailing comma is accepted.
/// - `bail!(span, "fmt", args...)` — error at `span`, message built with `format!`.
///
/// The error value is returned as-is (no `.into()`), so the enclosing function's
/// error type must be exactly `syn::Error`.
macro_rules! bail {
    // `_` in the span position: delegate to the other arms with the call-site span.
    // Keep this arm first: macro arms are tried in order, so this guarantees `_`
    // is handled here rather than by a `$span:expr` arm.
    (_, $($arg:tt)*) => {
        bail!(::proc_macro2::Span::call_site(), $($arg)*)
    };
    // Format string with no arguments (optional trailing comma).
    ($span:expr, $fmt:literal $(,)?) => {
        return ::std::result::Result::Err(::syn::Error::new($span, ::std::format!($fmt)))
    };
    // Format string followed by format arguments, forwarded verbatim to `format!`.
    ($span:expr, $fmt:literal, $($arg:tt)*) => {
        return ::std::result::Result::Err(::syn::Error::new($span, ::std::format!($fmt, $($arg)*)))
    };
}
26
27pub fn into_macro_output(input: Result<TokenStream>) -> proc_macro::TokenStream {
28    match input {
29        Ok(s) => s,
30        Err(e) => e.to_compile_error(),
31    }
32    .into()
33}
34
35pub struct GenericParamSet {
36    idents: HashSet<Ident>,
37}
38
39impl GenericParamSet {
40    pub fn new(generics: &Generics) -> Self {
41        let mut idents = HashSet::new();
42        for p in &generics.params {
43            match p {
44                GenericParam::Type(t) => {
45                    idents.insert(t.ident.unraw());
46                }
47                GenericParam::Const(t) => {
48                    idents.insert(t.ident.unraw());
49                }
50                GenericParam::Lifetime(_) => {}
51            }
52        }
53        Self { idents }
54    }
55    fn contains(&self, ident: &Ident) -> bool {
56        self.idents.contains(&ident.unraw())
57    }
58
59    pub fn contains_in_type(&self, ty: &Type) -> bool {
60        struct Visitor<'a> {
61            generics: &'a GenericParamSet,
62            result: bool,
63        }
64        impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
65            fn visit_path(&mut self, i: &'ast syn::Path) {
66                if i.leading_colon.is_none() {
67                    if let Some(s) = i.segments.iter().next() {
68                        if self.generics.contains(&s.ident) {
69                            self.result = true;
70                        }
71                    }
72                }
73                visit_path(self, i);
74            }
75        }
76        let mut visitor = Visitor {
77            generics: self,
78            result: false,
79        };
80        visitor.visit_type(ty);
81        visitor.result
82    }
83}
84
85pub enum Quotable<T> {
86    Direct(T),
87    Quoted { s: LitStr, args: ArgsOf<T> },
88}
89impl<T: Parse> Parse for Quotable<T> {
90    fn parse(input: ParseStream) -> Result<Self> {
91        let fork = input.fork();
92        if let Ok(s) = fork.parse::<LitStr>() {
93            input.advance_to(&fork);
94            let token: TokenStream = parse_str(&s.value())?;
95            let tokens = quote_spanned!(s.span()=> #token);
96            let args = parse2(tokens)?;
97            Ok(Quotable::Quoted { s, args })
98        } else {
99            Ok(Quotable::Direct(input.parse()?))
100        }
101    }
102}
103impl<T: ToTokens> ToTokens for Quotable<T> {
104    fn to_tokens(&self, tokens: &mut TokenStream) {
105        match self {
106            Self::Direct(value) => value.to_tokens(tokens),
107            Self::Quoted { s, .. } => s.to_tokens(tokens),
108        }
109    }
110}
111
112impl<T> Quotable<T> {
113    pub fn into_iter(self) -> impl IntoIterator<Item = T> {
114        match self {
115            Self::Direct(item) => vec![item],
116            Self::Quoted { args, .. } => args.into_iter().collect(),
117        }
118        .into_iter()
119    }
120}
121
122pub struct ArgsOf<T>(Punctuated<T, Token![,]>);
123
124impl<T: Parse> Parse for ArgsOf<T> {
125    fn parse(input: ParseStream) -> Result<Self> {
126        Ok(Self(Punctuated::parse_terminated(input)?))
127    }
128}
129impl<T: ToTokens> ToTokens for ArgsOf<T> {
130    fn to_tokens(&self, tokens: &mut TokenStream) {
131        self.0.to_tokens(tokens);
132    }
133}
134
135impl<T> ArgsOf<T> {
136    pub fn into_iter(self) -> impl Iterator<Item = T> {
137        self.0.into_iter()
138    }
139}
140
141pub fn impl_trait(
142    input: &DeriveInput,
143    trait_path: &Path,
144    wheres: &[WherePredicate],
145    contents: TokenStream,
146) -> TokenStream {
147    let ty = &input.ident;
148    let (impl_g, ty_g, where_clause) = input.generics.split_for_impl();
149    let mut wheres = wheres.to_vec();
150    if let Some(where_clause) = where_clause {
151        wheres.extend(where_clause.predicates.iter().cloned());
152    }
153    let where_clause = if wheres.is_empty() {
154        quote! {}
155    } else {
156        quote! { where #(#wheres,)*}
157    };
158    quote! {
159        #[automatically_derived]
160        impl #impl_g #trait_path for #ty #ty_g #where_clause {
161            #contents
162        }
163    }
164}
165pub fn impl_trait_result(
166    input: &DeriveInput,
167    trait_path: &Path,
168    wheres: &[WherePredicate],
169    contents: TokenStream,
170    dump: bool,
171) -> Result<TokenStream> {
172    let ts = impl_trait(input, trait_path, wheres, contents);
173    if dump {
174        panic!("macro output:\n{ts}");
175    }
176    Ok(ts)
177}