use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
braced,
parse::{End, Parse},
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};
use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
/// A parsed initializer invocation of the form
/// `#[attr]* [&ident in] Path { entry, ... [..rest] } [? ErrorType]`.
pub(crate) struct Initializer {
/// Outer attributes of the invocation; the parser only accepts `#[default_error(..)]`.
attrs: Vec<InitializerAttribute>,
/// Optional `&ident in` prefix.
this: Option<This>,
/// Path of the struct that is initialized.
path: Path,
/// Braces around the entry list; the closing brace's span is used for the diagnostic
/// emitted when no error type can be determined.
brace_token: token::Brace,
/// The comma-separated entries inside the braces.
fields: Punctuated<InitializerField, Token![,]>,
/// Optional `..expr` after the entries (expansion only accepts `..Zeroable::init_zeroed()`).
rest: Option<(Token![..], Expr)>,
/// Optional explicit `? ErrorType` after the closing brace.
error: Option<(Token![?], Type)>,
}
/// The `&ident in` prefix of an initializer. During expansion `ident` is bound to a
/// `NonNull` pointer to the slot being initialized, so entries can refer to it.
struct This {
_and_token: Token![&],
ident: Ident,
_in_token: Token![in],
}
/// One entry of the initializer's field list together with its outer attributes.
struct InitializerField {
/// Outer attributes of the entry; the `#[cfg(..)]` ones are additionally repeated on the
/// generated binding and drop guard of the field.
attrs: Vec<Attribute>,
kind: InitializerKind,
}
/// The syntactic form of one entry in the initializer's field list.
enum InitializerKind {
    /// `ident: expr`, or the shorthand `ident`, which uses the in-scope variable of the
    /// same name as the value. The field is written directly.
    Value {
        ident: Ident,
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`: the field is initialized in place by the initializer `expr`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: a block of code run at this position; it does not name a field.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}

impl InitializerKind {
    /// Name of the field this entry initializes; `None` for a `_: { ... }` code entry.
    fn ident(&self) -> Option<&Ident> {
        match self {
            Self::Code { .. } => None,
            Self::Init { ident, .. } => Some(ident),
            Self::Value { ident, .. } => Some(ident),
        }
    }
}
/// An attribute accepted on an initializer invocation.
enum InitializerAttribute {
/// `#[default_error(Type)]`.
DefaultError(DefaultErrorAttribute),
}
/// Argument of `#[default_error(Type)]`: the error type used when the invocation has no
/// explicit `? ErrorType` suffix.
struct DefaultErrorAttribute {
ty: Box<Type>,
}
/// Expands a parsed [`Initializer`] into a block expression that evaluates to an initializer
/// for the struct named by `path`.
///
/// * `default_error`: source text of the error type to use when the invocation has neither a
///   `? ErrorType` suffix nor a `#[default_error(..)]` attribute. It is parsed with
///   `syn::parse_str` and must be a valid type; an invalid string is a caller bug and panics.
/// * `pinned`: selects the pinned machinery (`HasPinData`, `PinData`, `__pin_data`,
///   `pin_init_from_closure`) instead of the unpinned one (`HasInitData`, `InitData`,
///   `__init_data`, `init_from_closure`).
/// * `dcx`: receives diagnostics for a missing error type and for an unsupported `..rest`.
///   Both cases fall back to a default so expansion continues; this function always
///   returns `Ok`.
///
/// The error type is chosen in this order: explicit `? ErrorType`, the last
/// `#[default_error(..)]` attribute, `default_error`, and finally
/// `::core::convert::Infallible` together with an error diagnostic.
pub(crate) fn expand(
    Initializer {
        attrs,
        this,
        path,
        brace_token,
        fields,
        rest,
        error,
    }: Initializer,
    default_error: Option<&'static str>,
    pinned: bool,
    dcx: &mut DiagCtxt,
) -> Result<TokenStream, ErrorGuaranteed> {
    let error = error.map_or_else(
        || {
            // When `#[default_error]` is given several times, the last occurrence wins.
            // The `match` is exhaustive on purpose: a new attribute variant must be handled.
            let from_attr = attrs.iter().rev().find_map(|attr| match attr {
                InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) => {
                    Some(ty.clone())
                }
            });
            if let Some(ty) = from_attr {
                ty
            } else if let Some(default_error) = default_error {
                syn::parse_str(default_error)
                    .expect("the `default_error` argument must be a valid Rust type")
            } else {
                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
                parse_quote!(::core::convert::Infallible)
            }
        },
        |(_, err)| Box::new(err),
    );
    let slot = format_ident!("slot");
    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
        (
            format_ident!("HasPinData"),
            format_ident!("PinData"),
            format_ident!("__pin_data"),
            format_ident!("pin_init_from_closure"),
        )
    } else {
        (
            format_ident!("HasInitData"),
            format_ident!("InitData"),
            format_ident!("__init_data"),
            format_ident!("init_from_closure"),
        )
    };
    let init_kind = get_init_kind(rest, dcx);
    // For `..Zeroable::init_zeroed()`: require `Zeroable` for the struct at compile time and
    // zero the whole slot before any entry is run.
    let zeroable_check = match init_kind {
        InitKind::Normal => quote!(),
        InitKind::Zeroing => quote! {
            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
            where T: ::pin_init::Zeroable
            {}
            assert_zeroable(#slot);
            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
        },
    };
    // For `&ident in`: expose a `NonNull` to the slot under the user's chosen name.
    let this = match this {
        None => quote!(),
        Some(This { ident, .. }) => quote! {
            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(#slot) };
        },
    };
    // Mixed-site hygiene keeps `__data` invisible to user-written expressions.
    let data = Ident::new("__data", Span::mixed_site());
    let init_fields = init_fields(&fields, pinned, &data, &slot);
    let field_check = make_field_check(&fields, init_kind, &path);
    Ok(quote! {{
        let #data = unsafe {
            use ::pin_init::__internal::#has_data_trait;
            #path::#get_data()
        };
        let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
            #data,
            move |slot| {
                #zeroable_check
                #this
                #init_fields
                #field_check
                Ok(unsafe { ::pin_init::__internal::InitOk::new() })
            }
        );
        // Discard the `InitOk` value so the closure matches the `Result<(), E>` signature.
        let init = move |slot| -> ::core::result::Result<(), #error> {
            init(slot).map(|__InitOk| ())
        };
        let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
        init
    }})
}
/// How the struct's fields are brought into an initialized state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InitKind {
    /// Every field is initialized by its own entry.
    Normal,
    /// `..Zeroable::init_zeroed()` was given: the whole struct is zeroed first, so entries
    /// for some fields may be omitted.
    Zeroing,
}
/// Classifies the optional `..rest` part of an initializer.
///
/// No rest selects [`InitKind::Normal`]. The exact form `..Zeroable::init_zeroed()` selects
/// [`InitKind::Zeroing`]. Anything else is reported as an error through `dcx` and then
/// treated as [`InitKind::Normal`] so expansion can continue.
fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
    let Some((dotdot, expr)) = rest else {
        return InitKind::Normal;
    };
    if is_zeroable_init_zeroed(&expr) {
        return InitKind::Zeroing;
    }
    // `Span::join` may not be available; fall back to the expression's span.
    let span = dotdot.span().join(expr.span()).unwrap_or(expr.span());
    dcx.error(span, "expected nothing or `..Zeroable::init_zeroed()`.");
    InitKind::Normal
}

/// Returns `true` iff `expr` is literally `Zeroable::init_zeroed()`: an argument-less call of
/// an attribute-free, unqualified, generic-free two-segment path.
fn is_zeroable_init_zeroed(expr: &Expr) -> bool {
    let Expr::Call(ExprCall { func, args, .. }) = expr else {
        return false;
    };
    if !args.is_empty() {
        return false;
    }
    let Expr::Path(ExprPath {
        attrs,
        qself: None,
        path:
            Path {
                leading_colon: None,
                segments,
            },
    }) = &**func
    else {
        return false;
    };
    attrs.is_empty()
        && segments.len() == 2
        && segments[0].ident == "Zeroable"
        && segments[0].arguments.is_none()
        && segments[1].ident == "init_zeroed"
        && segments[1].arguments.is_none()
}
fn init_fields(
fields: &Punctuated<InitializerField, Token![,]>,
pinned: bool,
data: &Ident,
slot: &Ident,
) -> TokenStream {
let mut guards = vec![];
let mut guard_attrs = vec![];
let mut res = TokenStream::new();
for InitializerField { attrs, kind } in fields {
let cfgs = {
let mut cfgs = attrs.clone();
cfgs.retain(|attr| attr.path().is_ident("cfg"));
cfgs
};
let init = match kind {
InitializerKind::Value { ident, value } => {
let mut value_ident = ident.clone();
let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
value_ident.set_span(value.span());
quote!(let #value_ident = #value;)
});
let write = quote_spanned!(ident.span()=> ::core::ptr::write);
let accessor = if pinned {
let project_ident = format_ident!("__project_{ident}");
quote! {
unsafe { #data.#project_ident(&mut (*#slot).#ident) }
}
} else {
quote! {
unsafe { &mut (*#slot).#ident }
}
};
quote! {
#(#attrs)*
{
#value_prep
unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
}
#(#cfgs)*
#[allow(unused_variables)]
let #ident = #accessor;
}
}
InitializerKind::Init { ident, value, .. } => {
let init = format_ident!("init", span = value.span());
let (value_init, accessor) = if pinned {
let project_ident = format_ident!("__project_{ident}");
(
quote! {
unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
},
quote! {
unsafe { #data.#project_ident(&mut (*#slot).#ident) }
},
)
} else {
(
quote! {
unsafe {
::pin_init::Init::__init(
#init,
::core::ptr::addr_of_mut!((*#slot).#ident),
)?
};
},
quote! {
unsafe { &mut (*#slot).#ident }
},
)
};
quote! {
#(#attrs)*
{
let #init = #value;
#value_init
}
#(#cfgs)*
#[allow(unused_variables)]
let #ident = #accessor;
}
}
InitializerKind::Code { block: value, .. } => quote! {
#(#attrs)*
#[allow(unused_braces)]
#value
},
};
res.extend(init);
if let Some(ident) = kind.ident() {
let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
res.extend(quote! {
#(#cfgs)*
let #guard = unsafe {
::pin_init::__internal::DropGuard::new(
::core::ptr::addr_of_mut!((*slot).#ident)
)
};
});
guards.push(guard);
guard_attrs.push(cfgs);
}
}
quote! {
#res
#(
#(#guard_attrs)*
::core::mem::forget(#guards);
)*
}
}
/// Emits a never-executed closure that writes `#path { field: panic!(), ... }` into `slot`.
///
/// Because the closure is type-checked like ordinary code, the compiler itself reports
/// missing or duplicated fields with its usual struct-literal diagnostics. With
/// [`InitKind::Zeroing`], the literal ends in `..::core::mem::zeroed()`, so omitted fields
/// are accepted.
fn make_field_check(
    fields: &Punctuated<InitializerField, Token![,]>,
    init_kind: InitKind,
    path: &Path,
) -> TokenStream {
    // `_: { ... }` code entries have no name and do not correspond to a struct field.
    let (attr_lists, names): (Vec<_>, Vec<_>) = fields
        .iter()
        .filter_map(|field| field.kind.ident().map(|name| (&field.attrs, name)))
        .unzip();
    let (extra_allow, struct_rest) = match init_kind {
        InitKind::Normal => (quote!(), quote!()),
        InitKind::Zeroing => (
            quote!(, unused_assignments),
            quote!(..::core::mem::zeroed()),
        ),
    };
    quote! {
        #[allow(unreachable_code, clippy::diverging_sub_expression #extra_allow)]
        let _ = || unsafe {
            ::core::ptr::write(slot, #path {
                #(
                    #(#attr_lists)*
                    #names: ::core::panic!(),
                )*
                #struct_rest
            })
        };
    }
}
impl Parse for Initializer {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
let attrs = input.call(Attribute::parse_outer)?;
let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
let path = input.parse()?;
let content;
let brace_token = braced!(content in input);
let mut fields = Punctuated::new();
loop {
let lh = content.lookahead1();
if lh.peek(End) || lh.peek(Token![..]) {
break;
} else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
fields.push_value(content.parse()?);
let lh = content.lookahead1();
if lh.peek(End) {
break;
} else if lh.peek(Token![,]) {
fields.push_punct(content.parse()?);
} else {
return Err(lh.error());
}
} else {
return Err(lh.error());
}
}
let rest = content
.peek(Token![..])
.then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
.transpose()?;
let error = input
.peek(Token![?])
.then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
.transpose()?;
let attrs = attrs
.into_iter()
.map(|a| {
if a.path().is_ident("default_error") {
a.parse_args::<DefaultErrorAttribute>()
.map(InitializerAttribute::DefaultError)
} else {
Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
}
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Self {
attrs,
this,
path,
brace_token,
fields,
rest,
error,
})
}
}
impl Parse for DefaultErrorAttribute {
    /// Parses the `Type` argument of `#[default_error(Type)]`.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        input.parse().map(|ty| Self { ty })
    }
}
impl Parse for This {
    /// Parses the `&ident in` prefix.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let and_token = input.parse()?;
        let ident = input.parse()?;
        let in_token = input.parse()?;
        Ok(Self {
            _and_token: and_token,
            ident,
            _in_token: in_token,
        })
    }
}
impl Parse for InitializerField {
    /// Parses the outer attributes of an entry followed by the entry itself.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        Ok(Self {
            attrs: Attribute::parse_outer(input)?,
            kind: input.parse()?,
        })
    }
}
impl Parse for InitializerKind {
    /// Parses one entry: `_: { ... }`, `ident <- expr`, `ident: expr`, or a bare `ident`
    /// (which must be followed by `,` or the end of the list).
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let first = input.lookahead1();
        if first.peek(Token![_]) {
            let underscore = input.parse()?;
            let colon = input.parse()?;
            let block = input.parse()?;
            return Ok(Self::Code {
                _underscore_token: underscore,
                _colon_token: colon,
                block,
            });
        }
        if !first.peek(Ident) {
            return Err(first.error());
        }

        let ident = input.parse()?;
        let after_ident = input.lookahead1();
        if after_ident.peek(Token![<-]) {
            let arrow = input.parse()?;
            let value = input.parse()?;
            return Ok(Self::Init {
                ident,
                _left_arrow_token: arrow,
                value,
            });
        }
        if after_ident.peek(Token![:]) {
            let colon = input.parse()?;
            let expr = input.parse()?;
            return Ok(Self::Value {
                ident,
                value: Some((colon, expr)),
            });
        }
        if after_ident.peek(Token![,]) || after_ident.peek(End) {
            return Ok(Self::Value { ident, value: None });
        }
        Err(after_ident.error())
    }
}