Path: blob/main/crates/component-macro/src/component.rs
1692 views
use proc_macro2::{Span, TokenStream};1use quote::{format_ident, quote};2use std::collections::HashSet;3use std::fmt;4use syn::parse::{Parse, ParseStream};5use syn::punctuated::Punctuated;6use syn::{Data, DeriveInput, Error, Ident, Result, Token, braced, parse_quote};7use wasmtime_component_util::{DiscriminantSize, FlagsSize};89mod kw {10syn::custom_keyword!(record);11syn::custom_keyword!(variant);12syn::custom_keyword!(flags);13syn::custom_keyword!(name);14syn::custom_keyword!(wasmtime_crate);15}1617#[derive(Debug, Copy, Clone)]18enum Style {19Record,20Enum,21Variant,22}2324impl fmt::Display for Style {25fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {26match self {27Style::Record => f.write_str("record"),28Style::Enum => f.write_str("enum"),29Style::Variant => f.write_str("variant"),30}31}32}3334#[derive(Debug, Clone)]35enum ComponentAttr {36Style(Style),37WasmtimeCrate(syn::Path),38}3940impl Parse for ComponentAttr {41fn parse(input: ParseStream) -> Result<Self> {42let lookahead = input.lookahead1();43if lookahead.peek(kw::record) {44input.parse::<kw::record>()?;45Ok(ComponentAttr::Style(Style::Record))46} else if lookahead.peek(kw::variant) {47input.parse::<kw::variant>()?;48Ok(ComponentAttr::Style(Style::Variant))49} else if lookahead.peek(Token![enum]) {50input.parse::<Token![enum]>()?;51Ok(ComponentAttr::Style(Style::Enum))52} else if lookahead.peek(kw::wasmtime_crate) {53input.parse::<kw::wasmtime_crate>()?;54input.parse::<Token![=]>()?;55Ok(ComponentAttr::WasmtimeCrate(input.parse()?))56} else if input.peek(kw::flags) {57Err(input.error(58"`flags` not allowed here; \59use `wasmtime::component::flags!` macro to define `flags` types",60))61} else {62Err(lookahead.error())63}64}65}6667fn find_rename(attributes: &[syn::Attribute]) -> Result<Option<syn::LitStr>> {68let mut name = None;6970for attribute in attributes {71if !attribute.path().is_ident("component") {72continue;73}74let name_literal = attribute.parse_args_with(|parser: ParseStream<'_>| {75parser.parse::<kw::name>()?;76parser.parse::<Token![=]>()?;77parser.parse::<syn::LitStr>()78})?;7980if name.is_some() {81return Err(Error::new_spanned(82attribute,83"duplicate field rename attribute",84));85}8687name = Some(name_literal);88}8990Ok(name)91}9293fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn::Generics {94let mut generics = generics.clone();95for param in &mut generics.params {96if let syn::GenericParam::Type(ref mut type_param) = *param {97type_param.bounds.push(bound.clone());98}99}100generics101}102103pub struct VariantCase<'a> {104attrs: &'a [syn::Attribute],105ident: &'a syn::Ident,106ty: Option<&'a syn::Type>,107}108109pub trait Expander {110fn expand_record(111&self,112name: &syn::Ident,113generics: &syn::Generics,114fields: &[&syn::Field],115wasmtime_crate: &syn::Path,116) -> Result<TokenStream>;117118fn expand_variant(119&self,120name: &syn::Ident,121generics: &syn::Generics,122discriminant_size: DiscriminantSize,123cases: &[VariantCase],124wasmtime_crate: &syn::Path,125) -> Result<TokenStream>;126127fn expand_enum(128&self,129name: &syn::Ident,130discriminant_size: DiscriminantSize,131cases: &[VariantCase],132wasmtime_crate: &syn::Path,133) -> Result<TokenStream>;134}135136pub fn expand(expander: &dyn Expander, input: &DeriveInput) -> Result<TokenStream> {137let mut wasmtime_crate = None;138let mut style = None;139140for attribute in &input.attrs {141if !attribute.path().is_ident("component") {142continue;143}144match attribute.parse_args()? {145ComponentAttr::WasmtimeCrate(c) => wasmtime_crate = Some(c),146ComponentAttr::Style(attr_style) => {147if style.is_some() {148return Err(Error::new_spanned(149attribute,150"duplicate `component` attribute",151));152}153style = Some(attr_style);154}155}156}157158let style = style.ok_or_else(|| Error::new_spanned(input, "missing `component` attribute"))?;159let wasmtime_crate = wasmtime_crate.unwrap_or_else(default_wasmtime_crate);160match style {161Style::Record => expand_record(expander, input, &wasmtime_crate),162Style::Enum | Style::Variant => expand_variant(expander, input, style, &wasmtime_crate),163}164}165166fn default_wasmtime_crate() -> syn::Path {167Ident::new("wasmtime", Span::call_site()).into()168}169170fn expand_record(171expander: &dyn Expander,172input: &DeriveInput,173wasmtime_crate: &syn::Path,174) -> Result<TokenStream> {175let name = &input.ident;176177let body = if let Data::Struct(body) = &input.data {178body179} else {180return Err(Error::new(181name.span(),182"`record` component types can only be derived for Rust `struct`s",183));184};185186match &body.fields {187syn::Fields::Named(fields) => expander.expand_record(188&input.ident,189&input.generics,190&fields.named.iter().collect::<Vec<_>>(),191wasmtime_crate,192),193194syn::Fields::Unnamed(_) | syn::Fields::Unit => Err(Error::new(195name.span(),196"`record` component types can only be derived for `struct`s with named fields",197)),198}199}200201fn expand_variant(202expander: &dyn Expander,203input: &DeriveInput,204style: Style,205wasmtime_crate: &syn::Path,206) -> Result<TokenStream> {207let name = &input.ident;208209let body = if let Data::Enum(body) = &input.data {210body211} else {212return Err(Error::new(213name.span(),214format!("`{style}` component types can only be derived for Rust `enum`s"),215));216};217218if body.variants.is_empty() {219return Err(Error::new(220name.span(),221format!(222"`{style}` component types can only be derived for Rust `enum`s with at least one variant"223),224));225}226227let discriminant_size = DiscriminantSize::from_count(body.variants.len()).ok_or_else(|| {228Error::new(229input.ident.span(),230"`enum`s with more than 2^32 variants are not supported",231)232})?;233234let cases = body235.variants236.iter()237.map(238|syn::Variant {239attrs,240ident,241fields,242..243}| {244Ok(VariantCase {245attrs,246ident,247ty: match fields {248syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {249Some(&fields.unnamed[0].ty)250}251syn::Fields::Unit => None,252_ => {253return Err(Error::new(254name.span(),255format!(256"`{}` component types can only be derived for Rust `enum`s \257containing variants with {}",258style,259match style {260Style::Variant => "at most one unnamed field each",261Style::Enum => "no fields",262Style::Record => unreachable!(),263}264),265));266}267},268})269},270)271.collect::<Result<Vec<_>>>()?;272273match style {274Style::Variant => expander.expand_variant(275&input.ident,276&input.generics,277discriminant_size,278&cases,279wasmtime_crate,280),281Style::Enum => {282validate_enum(input, &body, discriminant_size)?;283expander.expand_enum(&input.ident, discriminant_size, &cases, wasmtime_crate)284}285Style::Record => unreachable!(),286}287}288289/// Validates component model `enum` definitions are accompanied with290/// appropriate `#[repr]` tags. Additionally requires that no discriminants are291/// listed to ensure that unsafe transmutes in lift are valid.292fn validate_enum(input: &DeriveInput, body: &syn::DataEnum, size: DiscriminantSize) -> Result<()> {293if !input.generics.params.is_empty() {294return Err(Error::new_spanned(295&input.generics.params,296"cannot have generics on an `enum`",297));298}299if let Some(clause) = &input.generics.where_clause {300return Err(Error::new_spanned(301clause,302"cannot have a where clause on an `enum`",303));304}305let expected_discr = match size {306DiscriminantSize::Size1 => "u8",307DiscriminantSize::Size2 => "u16",308DiscriminantSize::Size4 => "u32",309};310let mut found_repr = false;311for attr in input.attrs.iter() {312if !attr.meta.path().is_ident("repr") {313continue;314}315let list = attr.meta.require_list()?;316found_repr = true;317if list.tokens.to_string() != expected_discr {318return Err(Error::new_spanned(319&list.tokens,320format!(321"expected `repr({expected_discr})`, found `repr({})`",322list.tokens323),324));325}326}327if !found_repr {328return Err(Error::new_spanned(329&body.enum_token,330format!("missing required `#[repr({expected_discr})]`"),331));332}333334for case in body.variants.iter() {335if let Some((_, expr)) = &case.discriminant {336return Err(Error::new_spanned(337expr,338"cannot have an explicit discriminant",339));340}341}342343Ok(())344}345346fn expand_record_for_component_type(347name: &syn::Ident,348generics: &syn::Generics,349fields: &[&syn::Field],350typecheck: TokenStream,351typecheck_argument: TokenStream,352wt: &syn::Path,353) -> Result<TokenStream> {354let internal = quote!(#wt::component::__internal);355356let mut lower_generic_params = TokenStream::new();357let mut lower_generic_args = TokenStream::new();358let mut lower_field_declarations = TokenStream::new();359let mut abi_list = TokenStream::new();360let mut unique_types = HashSet::new();361362for (index, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {363let generic = format_ident!("T{}", index);364365lower_generic_params.extend(quote!(#generic: Copy,));366lower_generic_args.extend(quote!(<#ty as #wt::component::ComponentType>::Lower,));367368lower_field_declarations.extend(quote!(#ident: #generic,));369370abi_list.extend(quote!(371<#ty as #wt::component::ComponentType>::ABI,372));373374unique_types.insert(ty);375}376377let generics = add_trait_bounds(generics, parse_quote!(#wt::component::ComponentType));378let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();379let lower = format_ident!("Lower{}", name);380381// You may wonder why we make the types of all the fields of the #lower struct generic. This is to work382// around the lack of [perfect derive support in383// rustc](https://smallcultfollowing.com/babysteps//blog/2022/04/12/implied-bounds-and-perfect-derive/#what-is-perfect-derive)384// as of this writing.385//386// If the struct we're deriving a `ComponentType` impl for has any generic parameters, then #lower needs387// generic parameters too. And if we just copy the parameters and bounds from the impl to #lower, then the388// `#[derive(Clone, Copy)]` will fail unless the original generics were declared with those bounds, which389// we don't want to require.390//391// Alternatively, we could just pass the `Lower` associated type of each generic type as arguments to392// #lower, but that would require distinguishing between generic and concrete types when generating393// #lower_field_declarations, which would require some form of symbol resolution. That doesn't seem worth394// the trouble.395396let expanded = quote! {397#[doc(hidden)]398#[derive(Clone, Copy)]399#[repr(C)]400pub struct #lower <#lower_generic_params> {401#lower_field_declarations402_align: [#wt::ValRaw; 0],403}404405unsafe impl #impl_generics #wt::component::ComponentType for #name #ty_generics #where_clause {406type Lower = #lower <#lower_generic_args>;407408const ABI: #internal::CanonicalAbiInfo =409#internal::CanonicalAbiInfo::record_static(&[#abi_list]);410411#[inline]412fn typecheck(413ty: &#internal::InterfaceType,414types: &#internal::InstanceType<'_>,415) -> #internal::anyhow::Result<()> {416#internal::#typecheck(ty, types, &[#typecheck_argument])417}418}419};420421Ok(quote!(const _: () = { #expanded };))422}423424fn quote(size: DiscriminantSize, discriminant: usize) -> TokenStream {425match size {426DiscriminantSize::Size1 => {427let discriminant = u8::try_from(discriminant).unwrap();428quote!(#discriminant)429}430DiscriminantSize::Size2 => {431let discriminant = u16::try_from(discriminant).unwrap();432quote!(#discriminant)433}434DiscriminantSize::Size4 => {435let discriminant = u32::try_from(discriminant).unwrap();436quote!(#discriminant)437}438}439}440441pub struct LiftExpander;442443impl Expander for LiftExpander {444fn expand_record(445&self,446name: &syn::Ident,447generics: &syn::Generics,448fields: &[&syn::Field],449wt: &syn::Path,450) -> Result<TokenStream> {451let internal = quote!(#wt::component::__internal);452453let mut lifts = TokenStream::new();454let mut loads = TokenStream::new();455456for (i, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {457let field_ty = quote!(ty.fields[#i].ty);458lifts.extend(459quote!(#ident: <#ty as #wt::component::Lift>::linear_lift_from_flat(460cx, #field_ty, &src.#ident461)?,),462);463464loads.extend(465quote!(#ident: <#ty as #wt::component::Lift>::linear_lift_from_memory(466cx, #field_ty,467&bytes468[<#ty as #wt::component::ComponentType>::ABI.next_field32_size(&mut offset)..]469[..<#ty as #wt::component::ComponentType>::SIZE32]470)?,),471);472}473474let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lift));475let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();476477let extract_ty = quote! {478let ty = match ty {479#internal::InterfaceType::Record(i) => &cx.types[i],480_ => #internal::bad_type_info(),481};482};483484let expanded = quote! {485unsafe impl #impl_generics #wt::component::Lift for #name #ty_generics #where_clause {486#[inline]487fn linear_lift_from_flat(488cx: &mut #internal::LiftContext<'_>,489ty: #internal::InterfaceType,490src: &Self::Lower,491) -> #internal::anyhow::Result<Self> {492#extract_ty493Ok(Self {494#lifts495})496}497498#[inline]499fn linear_lift_from_memory(500cx: &mut #internal::LiftContext<'_>,501ty: #internal::InterfaceType,502bytes: &[u8],503) -> #internal::anyhow::Result<Self> {504#extract_ty505debug_assert!(506(bytes.as_ptr() as usize)507% (<Self as #wt::component::ComponentType>::ALIGN32 as usize)508== 0509);510let mut offset = 0;511Ok(Self {512#loads513})514}515}516};517518Ok(expanded)519}520521fn expand_variant(522&self,523name: &syn::Ident,524generics: &syn::Generics,525discriminant_size: DiscriminantSize,526cases: &[VariantCase],527wt: &syn::Path,528) -> Result<TokenStream> {529let internal = quote!(#wt::component::__internal);530531let mut lifts = TokenStream::new();532let mut loads = TokenStream::new();533534for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {535let index_u32 = u32::try_from(index).unwrap();536537let index_quoted = quote(discriminant_size, index);538539if let Some(ty) = ty {540let payload_ty = quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info));541lifts.extend(542quote!(#index_u32 => Self::#ident(<#ty as #wt::component::Lift>::linear_lift_from_flat(543cx, #payload_ty, unsafe { &src.payload.#ident }544)?),),545);546547loads.extend(548quote!(#index_quoted => Self::#ident(<#ty as #wt::component::Lift>::linear_lift_from_memory(549cx, #payload_ty, &payload[..<#ty as #wt::component::ComponentType>::SIZE32]550)?),),551);552} else {553lifts.extend(quote!(#index_u32 => Self::#ident,));554555loads.extend(quote!(#index_quoted => Self::#ident,));556}557}558559let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lift));560let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();561562let from_bytes = match discriminant_size {563DiscriminantSize::Size1 => quote!(bytes[0]),564DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),565DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),566};567568let extract_ty = quote! {569let ty = match ty {570#internal::InterfaceType::Variant(i) => &cx.types[i],571_ => #internal::bad_type_info(),572};573};574575let expanded = quote! {576unsafe impl #impl_generics #wt::component::Lift for #name #ty_generics #where_clause {577#[inline]578fn linear_lift_from_flat(579cx: &mut #internal::LiftContext<'_>,580ty: #internal::InterfaceType,581src: &Self::Lower,582) -> #internal::anyhow::Result<Self> {583#extract_ty584Ok(match src.tag.get_u32() {585#lifts586discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),587})588}589590#[inline]591fn linear_lift_from_memory(592cx: &mut #internal::LiftContext<'_>,593ty: #internal::InterfaceType,594bytes: &[u8],595) -> #internal::anyhow::Result<Self> {596let align = <Self as #wt::component::ComponentType>::ALIGN32;597debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);598let discrim = #from_bytes;599let payload_offset = <Self as #internal::ComponentVariant>::PAYLOAD_OFFSET32;600let payload = &bytes[payload_offset..];601#extract_ty602Ok(match discrim {603#loads604discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),605})606}607}608};609610Ok(expanded)611}612613fn expand_enum(614&self,615name: &syn::Ident,616discriminant_size: DiscriminantSize,617cases: &[VariantCase],618wt: &syn::Path,619) -> Result<TokenStream> {620let internal = quote!(#wt::component::__internal);621622let (from_bytes, discrim_ty) = match discriminant_size {623DiscriminantSize::Size1 => (quote!(bytes[0]), quote!(u8)),624DiscriminantSize::Size2 => (625quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),626quote!(u16),627),628DiscriminantSize::Size4 => (629quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),630quote!(u32),631),632};633let discrim_limit = proc_macro2::Literal::usize_unsuffixed(cases.len());634635let extract_ty = quote! {636let ty = match ty {637#internal::InterfaceType::Enum(i) => &cx.types[i],638_ => #internal::bad_type_info(),639};640};641642let expanded = quote! {643unsafe impl #wt::component::Lift for #name {644#[inline]645fn linear_lift_from_flat(646cx: &mut #internal::LiftContext<'_>,647ty: #internal::InterfaceType,648src: &Self::Lower,649) -> #internal::anyhow::Result<Self> {650#extract_ty651let discrim = src.tag.get_u32();652if discrim >= #discrim_limit {653#internal::anyhow::bail!("unexpected discriminant: {discrim}");654}655Ok(unsafe {656#internal::transmute::<#discrim_ty, #name>(discrim as #discrim_ty)657})658}659660#[inline]661fn linear_lift_from_memory(662cx: &mut #internal::LiftContext<'_>,663ty: #internal::InterfaceType,664bytes: &[u8],665) -> #internal::anyhow::Result<Self> {666let align = <Self as #wt::component::ComponentType>::ALIGN32;667debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);668let discrim = #from_bytes;669if discrim >= #discrim_limit {670#internal::anyhow::bail!("unexpected discriminant: {discrim}");671}672Ok(unsafe {673#internal::transmute::<#discrim_ty, #name>(discrim)674})675}676}677};678679Ok(expanded)680}681}682683pub struct LowerExpander;684685impl Expander for LowerExpander {686fn expand_record(687&self,688name: &syn::Ident,689generics: &syn::Generics,690fields: &[&syn::Field],691wt: &syn::Path,692) -> Result<TokenStream> {693let internal = quote!(#wt::component::__internal);694695let mut lowers = TokenStream::new();696let mut stores = TokenStream::new();697698for (i, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {699let field_ty = quote!(ty.fields[#i].ty);700lowers.extend(quote!(#wt::component::Lower::linear_lower_to_flat(701&self.#ident, cx, #field_ty, #internal::map_maybe_uninit!(dst.#ident)702)?;));703704stores.extend(quote!(#wt::component::Lower::linear_lower_to_memory(705&self.#ident,706cx,707#field_ty,708<#ty as #wt::component::ComponentType>::ABI.next_field32_size(&mut offset),709)?;));710}711712let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lower));713let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();714715let extract_ty = quote! {716let ty = match ty {717#internal::InterfaceType::Record(i) => &cx.types[i],718_ => #internal::bad_type_info(),719};720};721722let expanded = quote! {723unsafe impl #impl_generics #wt::component::Lower for #name #ty_generics #where_clause {724#[inline]725fn linear_lower_to_flat<T>(726&self,727cx: &mut #internal::LowerContext<'_, T>,728ty: #internal::InterfaceType,729dst: &mut core::mem::MaybeUninit<Self::Lower>,730) -> #internal::anyhow::Result<()> {731#extract_ty732#lowers733Ok(())734}735736#[inline]737fn linear_lower_to_memory<T>(738&self,739cx: &mut #internal::LowerContext<'_, T>,740ty: #internal::InterfaceType,741mut offset: usize742) -> #internal::anyhow::Result<()> {743debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);744#extract_ty745#stores746Ok(())747}748}749};750751Ok(expanded)752}753754fn expand_variant(755&self,756name: &syn::Ident,757generics: &syn::Generics,758discriminant_size: DiscriminantSize,759cases: &[VariantCase],760wt: &syn::Path,761) -> Result<TokenStream> {762let internal = quote!(#wt::component::__internal);763764let mut lowers = TokenStream::new();765let mut stores = TokenStream::new();766767for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {768let index_u32 = u32::try_from(index).unwrap();769770let index_quoted = quote(discriminant_size, index);771772let discriminant_size = usize::from(discriminant_size);773774let pattern;775let lower;776let store;777778if ty.is_some() {779let ty = quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info));780pattern = quote!(Self::#ident(value));781lower = quote!(value.linear_lower_to_flat(cx, #ty, dst));782store = quote!(value.linear_lower_to_memory(783cx,784#ty,785offset + <Self as #internal::ComponentVariant>::PAYLOAD_OFFSET32,786));787} else {788pattern = quote!(Self::#ident);789lower = quote!(Ok(()));790store = quote!(Ok(()));791}792793lowers.extend(quote!(#pattern => {794#internal::map_maybe_uninit!(dst.tag).write(#wt::ValRaw::u32(#index_u32));795unsafe {796#internal::lower_payload(797#internal::map_maybe_uninit!(dst.payload),798|payload| #internal::map_maybe_uninit!(payload.#ident),799|dst| #lower,800)801}802}));803804stores.extend(quote!(#pattern => {805*cx.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();806#store807}));808}809810let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lower));811let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();812813let extract_ty = quote! {814let ty = match ty {815#internal::InterfaceType::Variant(i) => &cx.types[i],816_ => #internal::bad_type_info(),817};818};819820let expanded = quote! {821unsafe impl #impl_generics #wt::component::Lower for #name #ty_generics #where_clause {822#[inline]823fn linear_lower_to_flat<T>(824&self,825cx: &mut #internal::LowerContext<'_, T>,826ty: #internal::InterfaceType,827dst: &mut core::mem::MaybeUninit<Self::Lower>,828) -> #internal::anyhow::Result<()> {829#extract_ty830match self {831#lowers832}833}834835#[inline]836fn linear_lower_to_memory<T>(837&self,838cx: &mut #internal::LowerContext<'_, T>,839ty: #internal::InterfaceType,840mut offset: usize841) -> #internal::anyhow::Result<()> {842#extract_ty843debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);844match self {845#stores846}847}848}849};850851Ok(expanded)852}853854fn expand_enum(855&self,856name: &syn::Ident,857discriminant_size: DiscriminantSize,858_cases: &[VariantCase],859wt: &syn::Path,860) -> Result<TokenStream> {861let internal = quote!(#wt::component::__internal);862863let extract_ty = quote! {864let ty = match ty {865#internal::InterfaceType::Enum(i) => &cx.types[i],866_ => #internal::bad_type_info(),867};868};869870let (size, ty) = match discriminant_size {871DiscriminantSize::Size1 => (1, quote!(u8)),872DiscriminantSize::Size2 => (2, quote!(u16)),873DiscriminantSize::Size4 => (4, quote!(u32)),874};875let size = proc_macro2::Literal::usize_unsuffixed(size);876877let expanded = quote! {878unsafe impl #wt::component::Lower for #name {879#[inline]880fn linear_lower_to_flat<T>(881&self,882cx: &mut #internal::LowerContext<'_, T>,883ty: #internal::InterfaceType,884dst: &mut core::mem::MaybeUninit<Self::Lower>,885) -> #internal::anyhow::Result<()> {886#extract_ty887#internal::map_maybe_uninit!(dst.tag)888.write(#wt::ValRaw::u32(*self as u32));889Ok(())890}891892#[inline]893fn linear_lower_to_memory<T>(894&self,895cx: &mut #internal::LowerContext<'_, T>,896ty: #internal::InterfaceType,897mut offset: usize898) -> #internal::anyhow::Result<()> {899#extract_ty900debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);901let discrim = *self as #ty;902*cx.get::<#size>(offset) = discrim.to_le_bytes();903Ok(())904}905}906};907908Ok(expanded)909}910}911912pub struct ComponentTypeExpander;913914impl Expander for ComponentTypeExpander {915fn expand_record(916&self,917name: &syn::Ident,918generics: &syn::Generics,919fields: &[&syn::Field],920wt: &syn::Path,921) -> Result<TokenStream> {922expand_record_for_component_type(923name,924generics,925fields,926quote!(typecheck_record),927fields928.iter()929.map(930|syn::Field {931attrs, ident, ty, ..932}| {933let name = find_rename(attrs)?.unwrap_or_else(|| {934let ident = ident.as_ref().unwrap();935syn::LitStr::new(&ident.to_string(), ident.span())936});937938Ok(quote!((#name, <#ty as #wt::component::ComponentType>::typecheck),))939},940)941.collect::<Result<_>>()?,942wt,943)944}945946fn expand_variant(947&self,948name: &syn::Ident,949generics: &syn::Generics,950_discriminant_size: DiscriminantSize,951cases: &[VariantCase],952wt: &syn::Path,953) -> Result<TokenStream> {954let internal = quote!(#wt::component::__internal);955956let mut case_names_and_checks = TokenStream::new();957let mut lower_payload_generic_params = TokenStream::new();958let mut lower_payload_generic_args = TokenStream::new();959let mut lower_payload_case_declarations = TokenStream::new();960let mut lower_generic_args = TokenStream::new();961let mut abi_list = TokenStream::new();962let mut unique_types = HashSet::new();963964for (index, VariantCase { attrs, ident, ty }) in cases.iter().enumerate() {965let rename = find_rename(attrs)?;966967let name = rename.unwrap_or_else(|| syn::LitStr::new(&ident.to_string(), ident.span()));968969if let Some(ty) = ty {970abi_list.extend(quote!(Some(<#ty as #wt::component::ComponentType>::ABI),));971972case_names_and_checks.extend(973quote!((#name, Some(<#ty as #wt::component::ComponentType>::typecheck)),),974);975976let generic = format_ident!("T{}", index);977978lower_payload_generic_params.extend(quote!(#generic: Copy,));979lower_payload_generic_args.extend(quote!(#generic,));980lower_payload_case_declarations.extend(quote!(#ident: #generic,));981lower_generic_args.extend(quote!(<#ty as #wt::component::ComponentType>::Lower,));982983unique_types.insert(ty);984} else {985abi_list.extend(quote!(None,));986case_names_and_checks.extend(quote!((#name, None),));987lower_payload_case_declarations.extend(quote!(#ident: [#wt::ValRaw; 0],));988}989}990991let generics = add_trait_bounds(generics, parse_quote!(#wt::component::ComponentType));992let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();993let lower = format_ident!("Lower{}", name);994let lower_payload = format_ident!("LowerPayload{}", name);995996// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union997// generic. This is to work around a [normalization bug in998// rustc](https://github.com/rust-lang/rust/issues/90903) such that the compiler does not understand that999// e.g. `<i32 as ComponentType>::Lower` is `Copy` despite the bound specified in `ComponentType`'s1000// definition.1001//1002// See also the comment in `Self::expand_record` above for another reason why we do this.10031004let expanded = quote! {1005#[doc(hidden)]1006#[derive(Clone, Copy)]1007#[repr(C)]1008pub struct #lower<#lower_payload_generic_params> {1009tag: #wt::ValRaw,1010payload: #lower_payload<#lower_payload_generic_args>1011}10121013#[doc(hidden)]1014#[allow(non_snake_case)]1015#[derive(Clone, Copy)]1016#[repr(C)]1017union #lower_payload<#lower_payload_generic_params> {1018#lower_payload_case_declarations1019}10201021unsafe impl #impl_generics #wt::component::ComponentType for #name #ty_generics #where_clause {1022type Lower = #lower<#lower_generic_args>;10231024#[inline]1025fn typecheck(1026ty: &#internal::InterfaceType,1027types: &#internal::InstanceType<'_>,1028) -> #internal::anyhow::Result<()> {1029#internal::typecheck_variant(ty, types, &[#case_names_and_checks])1030}10311032const ABI: #internal::CanonicalAbiInfo =1033#internal::CanonicalAbiInfo::variant_static(&[#abi_list]);1034}10351036unsafe impl #impl_generics #internal::ComponentVariant for #name #ty_generics #where_clause {1037const CASES: &'static [Option<#internal::CanonicalAbiInfo>] = &[#abi_list];1038}1039};10401041Ok(quote!(const _: () = { #expanded };))1042}10431044fn expand_enum(1045&self,1046name: &syn::Ident,1047_discriminant_size: DiscriminantSize,1048cases: &[VariantCase],1049wt: &syn::Path,1050) -> Result<TokenStream> {1051let internal = quote!(#wt::component::__internal);10521053let mut case_names = TokenStream::new();1054let mut abi_list = TokenStream::new();10551056for VariantCase { attrs, ident, ty } in cases.iter() {1057let rename = find_rename(attrs)?;10581059let name = rename.unwrap_or_else(|| syn::LitStr::new(&ident.to_string(), ident.span()));10601061if ty.is_some() {1062return Err(Error::new(1063ident.span(),1064"payloads are not permitted for `enum` cases",1065));1066}1067abi_list.extend(quote!(None,));1068case_names.extend(quote!(#name,));1069}10701071let lower = format_ident!("Lower{}", name);10721073let cases_len = cases.len();1074let expanded = quote! {1075#[doc(hidden)]1076#[derive(Clone, Copy)]1077#[repr(C)]1078pub struct #lower {1079tag: #wt::ValRaw,1080}10811082unsafe impl #wt::component::ComponentType for #name {1083type Lower = #lower;10841085#[inline]1086fn typecheck(1087ty: &#internal::InterfaceType,1088types: &#internal::InstanceType<'_>,1089) -> #internal::anyhow::Result<()> {1090#internal::typecheck_enum(ty, types, &[#case_names])1091}10921093const ABI: #internal::CanonicalAbiInfo =1094#internal::CanonicalAbiInfo::enum_(#cases_len);1095}10961097unsafe impl #internal::ComponentVariant for #name {1098const CASES: &'static [Option<#internal::CanonicalAbiInfo>] = &[#abi_list];1099}1100};11011102Ok(quote!(const _: () = { #expanded };))1103}1104}11051106#[derive(Debug)]1107struct Flag {1108rename: Option<String>,1109name: String,1110}11111112impl Parse for Flag {1113fn parse(input: ParseStream) -> Result<Self> {1114let attributes = syn::Attribute::parse_outer(input)?;11151116let rename = find_rename(&attributes)?.map(|literal| literal.value());11171118input.parse::<Token![const]>()?;1119let name = input.parse::<syn::Ident>()?.to_string();11201121Ok(Self { rename, name })1122}1123}11241125#[derive(Debug)]1126pub struct Flags {1127name: String,1128flags: Vec<Flag>,1129}11301131impl Parse for Flags {1132fn parse(input: ParseStream) -> Result<Self> {1133let name = input.parse::<syn::Ident>()?.to_string();11341135let content;1136braced!(content in input);11371138let flags = content1139.parse_terminated(Flag::parse, Token![;])?1140.into_iter()1141.collect();11421143Ok(Self { name, flags })1144}1145}11461147pub fn expand_flags(flags: &Flags) -> Result<TokenStream> {1148let wt = default_wasmtime_crate();1149let size = FlagsSize::from_count(flags.flags.len());11501151let ty;1152let eq;11531154let count = flags.flags.len();11551156match size {1157FlagsSize::Size0 => {1158ty = quote!(());1159eq = quote!(true);1160}1161FlagsSize::Size1 => {1162ty = quote!(u8);11631164eq = if count == 8 {1165quote!(self.__inner0.eq(&rhs.__inner0))1166} else {1167let mask = !(0xFF_u8 << count);11681169quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask)))1170};1171}1172FlagsSize::Size2 => {1173ty = quote!(u16);11741175eq = if count == 16 {1176quote!(self.__inner0.eq(&rhs.__inner0))1177} else {1178let mask = !(0xFFFF_u16 << count);11791180quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask)))1181};1182}1183FlagsSize::Size4Plus(n) => {1184ty = quote!(u32);11851186let comparisons = (0..(n - 1))1187.map(|index| {1188let field = format_ident!("__inner{}", index);11891190quote!(self.#field.eq(&rhs.#field) &&)1191})1192.collect::<TokenStream>();11931194let field = format_ident!("__inner{}", n - 1);11951196eq = if count % 32 == 0 {1197quote!(#comparisons self.#field.eq(&rhs.#field))1198} else {1199let mask = !(0xFFFF_FFFF_u32 << (count % 32));12001201quote!(#comparisons (self.#field & #mask).eq(&(rhs.#field & #mask)))1202}1203}1204}12051206let count;1207let mut as_array;1208let mut bitor;1209let mut bitor_assign;1210let mut bitand;1211let mut bitand_assign;1212let mut bitxor;1213let mut bitxor_assign;1214let mut not;12151216match size {1217FlagsSize::Size0 => {1218count = 0;1219as_array = quote!([]);1220bitor = quote!(Self {});1221bitor_assign = quote!();1222bitand = quote!(Self {});1223bitand_assign = quote!();1224bitxor = quote!(Self {});1225bitxor_assign = quote!();1226not = quote!(Self {});1227}1228FlagsSize::Size1 | FlagsSize::Size2 => {1229count = 1;1230as_array = quote!([self.__inner0 as u32]);1231bitor = quote!(Self {1232__inner0: self.__inner0.bitor(rhs.__inner0)1233});1234bitor_assign = quote!(self.__inner0.bitor_assign(rhs.__inner0));1235bitand = quote!(Self {1236__inner0: self.__inner0.bitand(rhs.__inner0)1237});1238bitand_assign = quote!(self.__inner0.bitand_assign(rhs.__inner0));1239bitxor = quote!(Self {1240__inner0: self.__inner0.bitxor(rhs.__inner0)1241});1242bitxor_assign = quote!(self.__inner0.bitxor_assign(rhs.__inner0));1243not = quote!(Self {1244__inner0: self.__inner0.not()1245});1246}1247FlagsSize::Size4Plus(n) => {1248count = usize::from(n);1249as_array = TokenStream::new();1250bitor = TokenStream::new();1251bitor_assign = TokenStream::new();1252bitand = TokenStream::new();1253bitand_assign = TokenStream::new();1254bitxor = TokenStream::new();1255bitxor_assign = TokenStream::new();1256not = TokenStream::new();12571258for index in 0..n {1259let field = format_ident!("__inner{}", index);12601261as_array.extend(quote!(self.#field,));1262bitor.extend(quote!(#field: self.#field.bitor(rhs.#field),));1263bitor_assign.extend(quote!(self.#field.bitor_assign(rhs.#field);));1264bitand.extend(quote!(#field: self.#field.bitand(rhs.#field),));1265bitand_assign.extend(quote!(self.#field.bitand_assign(rhs.#field);));1266bitxor.extend(quote!(#field: self.#field.bitxor(rhs.#field),));1267bitxor_assign.extend(quote!(self.#field.bitxor_assign(rhs.#field);));1268not.extend(quote!(#field: self.#field.not(),));1269}12701271as_array = quote!([#as_array]);1272bitor = quote!(Self { #bitor });1273bitand = quote!(Self { #bitand });1274bitxor = quote!(Self { #bitxor });1275not = quote!(Self { #not });1276}1277};12781279let name = format_ident!("{}", flags.name);12801281let mut constants = TokenStream::new();1282let mut rust_names = TokenStream::new();1283let mut component_names = TokenStream::new();12841285for (index, Flag { name, rename }) in flags.flags.iter().enumerate() {1286rust_names.extend(quote!(#name,));12871288let component_name = rename.as_ref().unwrap_or(name);1289component_names.extend(quote!(#component_name,));12901291let fields = match size {1292FlagsSize::Size0 => quote!(),1293FlagsSize::Size1 => {1294let init = 1_u8 << index;1295quote!(__inner0: #init)1296}1297FlagsSize::Size2 => {1298let init = 1_u16 << index;1299quote!(__inner0: #init)1300}1301FlagsSize::Size4Plus(n) => (0..n)1302.map(|i| {1303let field = format_ident!("__inner{}", i);13041305let init = if index / 32 == usize::from(i) {13061_u32 << (index % 32)1307} else {130801309};13101311quote!(#field: #init,)1312})1313.collect::<TokenStream>(),1314};13151316let name = format_ident!("{}", name);13171318constants.extend(quote!(pub const #name: Self = Self { #fields };));1319}13201321let generics = syn::Generics {1322lt_token: None,1323params: Punctuated::new(),1324gt_token: None,1325where_clause: None,1326};13271328let fields = {1329let ty = syn::parse2::<syn::Type>(ty.clone())?;13301331(0..count)1332.map(|index| syn::Field {1333attrs: Vec::new(),1334vis: syn::Visibility::Inherited,1335ident: Some(format_ident!("__inner{}", index)),1336colon_token: None,1337ty: ty.clone(),1338mutability: syn::FieldMutability::None,1339})1340.collect::<Vec<_>>()1341};13421343let fields = fields.iter().collect::<Vec<_>>();13441345let component_type_impl = expand_record_for_component_type(1346&name,1347&generics,1348&fields,1349quote!(typecheck_flags),1350component_names,1351&wt,1352)?;13531354let internal = quote!(#wt::component::__internal);13551356let field_names = fields1357.iter()1358.map(|syn::Field { ident, .. }| ident)1359.collect::<Vec<_>>();13601361let fields = fields1362.iter()1363.map(|syn::Field { ident, .. }| quote!(#[doc(hidden)] #ident: #ty,))1364.collect::<TokenStream>();13651366let (field_interface_type, field_size) = match size {1367FlagsSize::Size0 => (quote!(NOT USED), 0usize),1368FlagsSize::Size1 => (quote!(#internal::InterfaceType::U8), 1),1369FlagsSize::Size2 => (quote!(#internal::InterfaceType::U16), 2),1370FlagsSize::Size4Plus(_) => (quote!(#internal::InterfaceType::U32), 4),1371};13721373let expanded = quote! {1374#[derive(Copy, Clone, Default)]1375pub struct #name { #fields }13761377impl #name {1378#constants13791380pub fn as_array(&self) -> [u32; #count] {1381#as_array1382}13831384pub fn empty() -> Self {1385Self::default()1386}13871388pub fn all() -> Self {1389use core::ops::Not;1390Self::default().not()1391}13921393pub fn contains(&self, other: Self) -> bool {1394*self & other == other1395}13961397pub fn intersects(&self, other: Self) -> bool {1398*self & other != Self::empty()1399}1400}14011402impl core::cmp::PartialEq for #name {1403fn eq(&self, rhs: &#name) -> bool {1404#eq1405}1406}14071408impl core::cmp::Eq for #name { }14091410impl core::fmt::Debug for #name {1411fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {1412#internal::format_flags(&self.as_array(), &[#rust_names], f)1413}1414}14151416impl core::ops::BitOr for #name {1417type Output = #name;14181419fn bitor(self, rhs: #name) -> #name {1420#bitor1421}1422}14231424impl core::ops::BitOrAssign for #name {1425fn bitor_assign(&mut self, rhs: #name) {1426#bitor_assign1427}1428}14291430impl core::ops::BitAnd for #name {1431type Output = #name;14321433fn bitand(self, rhs: #name) -> #name {1434#bitand1435}1436}14371438impl core::ops::BitAndAssign for #name {1439fn bitand_assign(&mut self, rhs: #name) {1440#bitand_assign1441}1442}14431444impl core::ops::BitXor for #name {1445type Output = #name;14461447fn bitxor(self, rhs: #name) -> #name {1448#bitxor1449}1450}14511452impl core::ops::BitXorAssign for #name {1453fn bitxor_assign(&mut self, rhs: #name) {1454#bitxor_assign1455}1456}14571458impl core::ops::Not for #name {1459type Output = #name;14601461fn not(self) -> #name {1462#not1463}1464}14651466#component_type_impl14671468unsafe impl #wt::component::Lower for #name {1469fn linear_lower_to_flat<T>(1470&self,1471cx: &mut #internal::LowerContext<'_, T>,1472_ty: #internal::InterfaceType,1473dst: &mut core::mem::MaybeUninit<Self::Lower>,1474) -> #internal::anyhow::Result<()> {1475#(1476self.#field_names.linear_lower_to_flat(1477cx,1478#field_interface_type,1479#internal::map_maybe_uninit!(dst.#field_names),1480)?;1481)*1482Ok(())1483}14841485fn linear_lower_to_memory<T>(1486&self,1487cx: &mut #internal::LowerContext<'_, T>,1488_ty: #internal::InterfaceType,1489mut offset: usize1490) -> #internal::anyhow::Result<()> {1491debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);1492#(1493self.#field_names.linear_lower_to_memory(1494cx,1495#field_interface_type,1496offset,1497)?;1498offset += core::mem::size_of_val(&self.#field_names);1499)*1500Ok(())1501}1502}15031504unsafe impl #wt::component::Lift for #name {1505fn linear_lift_from_flat(1506cx: &mut #internal::LiftContext<'_>,1507_ty: #internal::InterfaceType,1508src: &Self::Lower,1509) -> #internal::anyhow::Result<Self> {1510Ok(Self {1511#(1512#field_names: #wt::component::Lift::linear_lift_from_flat(1513cx,1514#field_interface_type,1515&src.#field_names,1516)?,1517)*1518})1519}15201521fn linear_lift_from_memory(1522cx: &mut #internal::LiftContext<'_>,1523_ty: #internal::InterfaceType,1524bytes: &[u8],1525) -> #internal::anyhow::Result<Self> {1526debug_assert!(1527(bytes.as_ptr() as usize)1528% (<Self as #wt::component::ComponentType>::ALIGN32 as usize)1529== 01530);1531#(1532let (field, bytes) = bytes.split_at(#field_size);1533let #field_names = #wt::component::Lift::linear_lift_from_memory(1534cx,1535#field_interface_type,1536field,1537)?;1538)*1539Ok(Self { #(#field_names,)* })1540}1541}1542};15431544Ok(expanded)1545}154615471548