// Path: bit_field/bit_field_derive/src/bit_field_derive.rs
// Copyright 2018 The ChromiumOS Authors1// Use of this source code is governed by a BSD-style license that can be2// found in the LICENSE file.34#![recursion_limit = "256"]56extern crate proc_macro;78use proc_macro2::Span;9use proc_macro2::TokenStream;10use quote::quote;11use quote::quote_spanned;12use syn::parse::Error;13use syn::parse::Result;14use syn::parse_macro_input;15use syn::Attribute;16use syn::Data;17use syn::DataEnum;18use syn::DeriveInput;19use syn::Fields;20use syn::FieldsNamed;21use syn::FieldsUnnamed;22use syn::Ident;23use syn::Lit;24use syn::LitInt;25use syn::Meta;26use syn::MetaNameValue;27use syn::Type;28use syn::Visibility;2930/// The function that derives the actual implementation.31#[proc_macro_attribute]32pub fn bitfield(33_args: proc_macro::TokenStream,34input: proc_macro::TokenStream,35) -> proc_macro::TokenStream {36let derive_input = parse_macro_input!(input as DeriveInput);3738let expanded = bitfield_impl(&derive_input).unwrap_or_else(|err| {39let compile_error = err.to_compile_error();40quote! 
{41#compile_error4243// Include the original input to avoid "use of undeclared type"44// errors elsewhere.45#derive_input46}47});4849expanded.into()50}5152fn bitfield_impl(ast: &DeriveInput) -> Result<TokenStream> {53if !ast.generics.params.is_empty() {54return Err(Error::new(55Span::call_site(),56"#[bitfield] does not support generic parameters",57));58}5960match &ast.data {61Data::Struct(data_struct) => match &data_struct.fields {62Fields::Named(fields_named) => bitfield_struct_impl(ast, fields_named),63Fields::Unnamed(fields_unnamed) => bitfield_tuple_struct_impl(ast, fields_unnamed),64Fields::Unit => Err(Error::new(65Span::call_site(),66"#[bitfield] does not work with unit struct",67)),68},69Data::Enum(data_enum) => bitfield_enum_impl(ast, data_enum),70Data::Union(_) => Err(Error::new(71Span::call_site(),72"#[bitfield] does not support unions",73)),74}75}7677fn bitfield_tuple_struct_impl(ast: &DeriveInput, fields: &FieldsUnnamed) -> Result<TokenStream> {78let mut ast = ast.clone();79let width = match parse_remove_bits_attr(&mut ast)? {80Some(w) => w,81None => {82return Err(Error::new(83Span::call_site(),84"tuple struct field must have bits attribute",85));86}87};8889let ident = &ast.ident;9091if width > 64 {92return Err(Error::new(93Span::call_site(),94"max width of bitfield field is 64",95));96}9798let bits = width as u8;99100if fields.unnamed.len() != 1 {101return Err(Error::new(102Span::call_site(),103"tuple struct field must have exactly 1 field",104));105}106107let field_type = match &fields.unnamed.first().unwrap().ty {108Type::Path(t) => t,109_ => {110return Err(Error::new(111Span::call_site(),112"tuple struct field must have primitive field",113));114}115};116let span = field_type.path.segments.first().unwrap().ident.span();117118let from_u64 = quote_spanned! {119span => val as #field_type120};121122let into_u64 = quote_spanned! {123span => val.0 as u64124};125126let expanded = quote! 
{127#ast128129impl bit_field::BitFieldSpecifier for #ident {130const FIELD_WIDTH: u8 = #bits;131type SetterType = Self;132type GetterType = Self;133134#[inline]135#[allow(clippy::unnecessary_cast)]136fn from_u64(val: u64) -> Self::GetterType {137Self(#from_u64)138}139140#[inline]141#[allow(clippy::unnecessary_cast)]142fn into_u64(val: Self::SetterType) -> u64 {143#into_u64144}145}146};147148Ok(expanded)149}150151fn bitfield_enum_impl(ast: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {152let mut ast = ast.clone();153let width = parse_remove_bits_attr(&mut ast)?;154match width {155None => bitfield_enum_without_width_impl(&ast, data),156Some(width) => bitfield_enum_with_width_impl(&ast, data, width),157}158}159160fn bitfield_enum_with_width_impl(161ast: &DeriveInput,162data: &DataEnum,163width: u64,164) -> Result<TokenStream> {165if width > 64 {166return Err(Error::new(167Span::call_site(),168"max width of bitfield enum is 64",169));170}171let bits = width as u8;172let declare_discriminants = get_declare_discriminants_for_enum(bits, ast, data);173174let ident = &ast.ident;175let type_name = ident.to_string();176let variants = &data.variants;177let match_discriminants = variants.iter().map(|variant| {178let variant = &variant.ident;179quote! {180discriminant::#variant => Ok(#ident::#variant),181}182});183184let expanded = quote! 
{185#ast186187impl bit_field::BitFieldSpecifier for #ident {188const FIELD_WIDTH: u8 = #bits;189type SetterType = Self;190type GetterType = std::result::Result<Self, bit_field::Error>;191192#[inline]193fn from_u64(val: u64) -> Self::GetterType {194struct discriminant;195impl discriminant {196#(#declare_discriminants)*197}198match val {199#(#match_discriminants)*200v => Err(bit_field::Error::new(#type_name, v)),201}202}203204#[inline]205fn into_u64(val: Self::SetterType) -> u64 {206val as u64207}208}209};210211Ok(expanded)212}213// Expand to an impl of BitFieldSpecifier for an enum like:214//215// #[bitfield]216// #[derive(Debug, PartialEq)]217// enum TwoBits {218// Zero = 0b00,219// One = 0b01,220// Two = 0b10,221// Three = 0b11,222// }223//224// Such enums may be used as a field of a bitfield struct.225//226// #[bitfield]227// struct Struct {228// prefix: BitField1,229// two_bits: TwoBits,230// suffix: BitField5,231// }232//233fn bitfield_enum_without_width_impl(ast: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {234let ident = &ast.ident;235let variants = &data.variants;236let len = variants.len();237if len.count_ones() != 1 {238return Err(Error::new(239Span::call_site(),240"#[bitfield] expected a number of variants which is a power of 2 when bits is not \241specified for the enum",242));243}244245let bits = len.trailing_zeros() as u8;246let declare_discriminants = get_declare_discriminants_for_enum(bits, ast, data);247248let match_discriminants = variants.iter().map(|variant| {249let variant = &variant.ident;250quote! {251discriminant::#variant => #ident::#variant,252}253});254255let expanded = quote! 
{256#ast257258impl bit_field::BitFieldSpecifier for #ident {259const FIELD_WIDTH: u8 = #bits;260type SetterType = Self;261type GetterType = Self;262263#[inline]264fn from_u64(val: u64) -> Self::GetterType {265struct discriminant;266impl discriminant {267#(#declare_discriminants)*268}269match val {270#(#match_discriminants)*271_ => unreachable!(),272}273}274275#[inline]276fn into_u64(val: Self::SetterType) -> u64 {277val as u64278}279}280};281282Ok(expanded)283}284285fn get_declare_discriminants_for_enum(286bits: u8,287ast: &DeriveInput,288data: &DataEnum,289) -> Vec<TokenStream> {290let variants = &data.variants;291let upper_bound = 2u64.pow(bits as u32);292let ident = &ast.ident;293294variants295.iter()296.map(|variant| {297let variant = &variant.ident;298let span = variant.span();299300let assertion = quote_spanned! {span=>301// If IS_IN_BOUNDS is true, this evaluates to 0.302//303// If IS_IN_BOUNDS is false, this evaluates to `0 - 1` which304// triggers a compile error on underflow when referenced below. The305// error is not beautiful but does carry the span of the problematic306// enum variant so at least it points to the right line.307//308// error: any use of this value will cause an error309// --> bit_field/test.rs:10:5310// |311// 10 | OutOfBounds = 0b111111,312// | ^^^^^^^^^^^ attempt to subtract with overflow313// |314//315// error[E0080]: erroneous constant used316// --> bit_field/test.rs:5:1317// |318// 5 | #[bitfield]319// | ^^^^^^^^^^^ referenced constant has errors320//321const ASSERT: u64 = 0 - !IS_IN_BOUNDS as u64;322};323324quote! 
{325#[allow(non_upper_case_globals)]326const #variant: u64 = {327const IS_IN_BOUNDS: bool = (#ident::#variant as u64) < #upper_bound;328329#assertion330331#ident::#variant as u64 + ASSERT332};333}334})335.collect()336}337338fn bitfield_struct_impl(ast: &DeriveInput, fields: &FieldsNamed) -> Result<TokenStream> {339let name = &ast.ident;340let vis = &ast.vis;341let attrs = &ast.attrs;342let fields = get_struct_fields(fields)?;343let struct_def = get_struct_def(vis, name, &fields);344let bits_impl = get_bits_impl(name);345let fields_impl = get_fields_impl(&fields);346let debug_fmt_impl = get_debug_fmt_impl(name, &fields);347348let expanded = quote! {349#(#attrs)*350#struct_def351#bits_impl352impl #name {353#(#fields_impl)*354}355#debug_fmt_impl356};357358Ok(expanded)359}360361struct FieldSpec<'a> {362ident: &'a Ident,363ty: &'a Type,364expected_bits: Option<LitInt>,365}366367// Unwrap ast to get the named fields. We only care about field names and types:368// "myfield : BitField3" -> ("myfield", Token(BitField3))369fn get_struct_fields(fields: &FieldsNamed) -> Result<Vec<FieldSpec>> {370let mut vec = Vec::new();371372for field in &fields.named {373let ident = field374.ident375.as_ref()376.expect("Fields::Named has named fields");377let ty = &field.ty;378let expected_bits = parse_bits_attr(&field.attrs)?;379vec.push(FieldSpec {380ident,381ty,382expected_bits,383});384}385386Ok(vec)387}388389// For example: #[bits = 1]390fn parse_bits_attr(attrs: &[Attribute]) -> Result<Option<LitInt>> {391let mut expected_bits = None;392393for attr in attrs {394if attr.path().is_ident("doc") {395continue;396}397if let Some(v) = try_parse_bits_attr(attr) {398expected_bits = Some(v);399continue;400}401402return Err(Error::new_spanned(attr, "unrecognized attribute"));403}404405Ok(expected_bits)406}407408// This function will return None if the attribute is not #[bits = *].409fn try_parse_bits_attr(attr: &Attribute) -> Option<LitInt> {410if attr.path().is_ident("bits") {411if let 
Meta::NameValue(MetaNameValue {412value:413syn::Expr::Lit(syn::ExprLit {414lit: Lit::Int(int), ..415}),416..417}) = &attr.meta418{419return Some(int).cloned();420}421}422None423}424425fn parse_remove_bits_attr(ast: &mut DeriveInput) -> Result<Option<u64>> {426let mut width = None;427let mut bits_idx = 0;428429for (i, attr) in ast.attrs.iter().enumerate() {430if let Some(w) = try_parse_bits_attr(attr) {431bits_idx = i;432width = Some(w.base10_parse()?);433}434}435436if width.is_some() {437ast.attrs.remove(bits_idx);438}439440Ok(width)441}442443fn get_struct_def(vis: &Visibility, name: &Ident, fields: &[FieldSpec]) -> TokenStream {444let mut field_types = Vec::new();445for spec in fields {446field_types.push(spec.ty);447}448449// `(BitField1::FIELD_WIDTH + BitField3::FIELD_WIDTH + ...)`450let data_size_in_bits = quote! {451(452#(453<#field_types as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize454)+*455)456};457458quote! {459#[repr(C)]460#vis struct #name {461data: [u8; #data_size_in_bits / 8],462}463464impl #name {465pub fn new() -> #name {466let _: ::bit_field::Check<[u8; #data_size_in_bits % 8]>;467468#name {469data: [0; #data_size_in_bits / 8],470}471}472}473}474}475476// Implement setter and getter for all fields.477fn get_fields_impl(fields: &[FieldSpec]) -> Vec<TokenStream> {478let mut impls = Vec::new();479// This vec keeps track of types before this field, used to generate the offset.480let current_types = &mut vec![quote!(::bit_field::BitField0)];481482for spec in fields {483let ty = spec.ty;484let getter_ident = Ident::new(format!("get_{}", spec.ident).as_str(), Span::call_site());485let setter_ident = Ident::new(format!("set_{}", spec.ident).as_str(), Span::call_site());486487// Optional #[bits = N] attribute to provide compile-time checked488// documentation of how many bits some field covers.489let check_expected_bits = spec.expected_bits.as_ref().map(|expected_bits| {490// If expected_bits does not match the actual number of bits in the491// bit 
field specifier, this will fail to compile with an error492// pointing into the #[bits = N] attribute.493let span = expected_bits.span();494quote_spanned! {span=>495#[allow(dead_code)]496const EXPECTED_BITS: [(); #expected_bits] =497[(); <#ty as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize];498}499});500501impls.push(quote! {502pub fn #getter_ident(&self) -> <#ty as ::bit_field::BitFieldSpecifier>::GetterType {503#check_expected_bits504let offset = #(<#current_types as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize)+*;505let val = self.get(offset, <#ty as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH);506<#ty as ::bit_field::BitFieldSpecifier>::from_u64(val)507}508509pub fn #setter_ident(&mut self, val: <#ty as ::bit_field::BitFieldSpecifier>::SetterType) {510let val = <#ty as ::bit_field::BitFieldSpecifier>::into_u64(val);511debug_assert!(val <= ::bit_field::max::<#ty>());512let offset = #(<#current_types as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize)+*;513self.set(offset, <#ty as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH, val)514}515});516517current_types.push(quote!(#ty));518}519520impls521}522523// Implement setter and getter for all fields.524fn get_debug_fmt_impl(name: &Ident, fields: &[FieldSpec]) -> TokenStream {525// print fields:526let mut impls = Vec::new();527for spec in fields {528let field_name = spec.ident.to_string();529let getter_ident = Ident::new(&format!("get_{}", spec.ident), Span::call_site());530impls.push(quote! {531.field(#field_name, &self.#getter_ident())532});533}534535let name_str = format!("{name}");536quote! {537impl std::fmt::Debug for #name {538fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {539f.debug_struct(#name_str)540#(#impls)*541.finish()542}543}544}545}546547fn get_bits_impl(name: &Ident) -> TokenStream {548quote! 
{549impl #name {550#[inline]551fn check_access(&self, offset: usize, width: u8) {552debug_assert!(width <= 64);553debug_assert!(offset / 8 < self.data.len());554debug_assert!((offset + (width as usize)) <= (self.data.len() * 8));555}556557#[inline]558pub fn get_bit(&self, offset: usize) -> bool {559self.check_access(offset, 1);560561let byte_index = offset / 8;562let bit_offset = offset % 8;563564let byte = self.data[byte_index];565let mask = 1 << bit_offset;566567byte & mask == mask568}569570#[inline]571pub fn set_bit(&mut self, offset: usize, val: bool) {572self.check_access(offset, 1);573574let byte_index = offset / 8;575let bit_offset = offset % 8;576577let byte = &mut self.data[byte_index];578let mask = 1 << bit_offset;579580if val {581*byte |= mask;582} else {583*byte &= !mask;584}585}586587#[inline]588pub fn get(&self, offset: usize, width: u8) -> u64 {589self.check_access(offset, width);590let mut val = 0;591592for i in 0..(width as usize) {593if self.get_bit(i + offset) {594val |= 1 << i;595}596}597598val599}600601#[inline]602pub fn set(&mut self, offset: usize, width: u8, val: u64) {603self.check_access(offset, width);604605for i in 0..(width as usize) {606let mask = 1 << i;607let val_bit_is_set = val & mask == mask;608self.set_bit(i + offset, val_bit_is_set);609}610}611}612}613}614615// Only intended to be used from the bit_field crate. 
This macro emits the616// marker types bit_field::BitField0 through bit_field::BitField64.617#[proc_macro]618#[doc(hidden)]619pub fn define_bit_field_specifiers(_input: proc_macro::TokenStream) -> proc_macro::TokenStream {620let mut code = TokenStream::new();621622for width in 0u8..=64 {623let span = Span::call_site();624let long_name = Ident::new(&format!("BitField{width}"), span);625let short_name = Ident::new(&format!("B{width}"), span);626627let default_field_type = if width <= 8 {628quote!(u8)629} else if width <= 16 {630quote!(u16)631} else if width <= 32 {632quote!(u32)633} else {634quote!(u64)635};636637code.extend(quote! {638pub struct #long_name;639pub use self::#long_name as #short_name;640641impl BitFieldSpecifier for #long_name {642const FIELD_WIDTH: u8 = #width;643type SetterType = #default_field_type;644type GetterType = #default_field_type;645646#[inline]647fn from_u64(val: u64) -> Self::GetterType {648val as Self::GetterType649}650651#[inline]652fn into_u64(val: Self::SetterType) -> u64 {653val as u64654}655}656});657}658659code.into()660}661662#[cfg(test)]663mod tests {664use syn::parse_quote;665666use super::*;667668#[test]669fn end_to_end() {670let input: DeriveInput = parse_quote! {671#[derive(Clone)]672struct MyBitField {673a: BitField1,674b: BitField2,675c: BitField5,676}677};678679let expected = quote! 
{680#[derive(Clone)]681#[repr(C)]682struct MyBitField {683data: [u8; (<BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize684+ <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize685+ <BitField5 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize)686/ 8],687}688impl MyBitField {689pub fn new() -> MyBitField {690let _: ::bit_field::Check<[691u8;692(<BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize693+ <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize694+ <BitField5 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize)695% 8696]>;697698MyBitField {699data: [0; (<BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize700+ <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize701+ <BitField5 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize)702/ 8],703}704}705}706impl MyBitField {707#[inline]708fn check_access(&self, offset: usize, width: u8) {709debug_assert!(width <= 64);710debug_assert!(offset / 8 < self.data.len());711debug_assert!((offset + (width as usize)) <= (self.data.len() * 8));712}713#[inline]714pub fn get_bit(&self, offset: usize) -> bool {715self.check_access(offset, 1);716let byte_index = offset / 8;717let bit_offset = offset % 8;718let byte = self.data[byte_index];719let mask = 1 << bit_offset;720byte & mask == mask721}722#[inline]723pub fn set_bit(&mut self, offset: usize, val: bool) {724self.check_access(offset, 1);725let byte_index = offset / 8;726let bit_offset = offset % 8;727let byte = &mut self.data[byte_index];728let mask = 1 << bit_offset;729if val {730*byte |= mask;731} else {732*byte &= !mask;733}734}735#[inline]736pub fn get(&self, offset: usize, width: u8) -> u64 {737self.check_access(offset, width);738let mut val = 0;739for i in 0..(width as usize) {740if self.get_bit(i + offset) {741val |= 1 << i;742}743}744val745}746#[inline]747pub fn set(&mut self, offset: usize, width: u8, val: u64) {748self.check_access(offset, width);749for i in 
0..(width as usize) {750let mask = 1 << i;751let val_bit_is_set = val & mask == mask;752self.set_bit(i + offset, val_bit_is_set);753}754}755}756impl MyBitField {757pub fn get_a(&self) -> <BitField1 as ::bit_field::BitFieldSpecifier>::GetterType {758let offset = <::bit_field::BitField0 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize;759let val = self.get(offset, <BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH);760<BitField1 as ::bit_field::BitFieldSpecifier>::from_u64(val)761}762pub fn set_a(&mut self, val: <BitField1 as ::bit_field::BitFieldSpecifier>::SetterType) {763let val = <BitField1 as ::bit_field::BitFieldSpecifier>::into_u64(val);764debug_assert!(val <= ::bit_field::max::<BitField1>());765let offset = <::bit_field::BitField0 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize;766self.set(offset, <BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH, val)767}768pub fn get_b(&self) -> <BitField2 as ::bit_field::BitFieldSpecifier>::GetterType {769let offset = <::bit_field::BitField0 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize770+ <BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize;771let val = self.get(offset, <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH);772<BitField2 as ::bit_field::BitFieldSpecifier>::from_u64(val)773}774pub fn set_b(&mut self, val: <BitField2 as ::bit_field::BitFieldSpecifier>::SetterType) {775let val = <BitField2 as ::bit_field::BitFieldSpecifier>::into_u64(val);776debug_assert!(val <= ::bit_field::max::<BitField2>());777let offset = <::bit_field::BitField0 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize778+ <BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize;779self.set(offset, <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH, val)780}781pub fn get_c(&self) -> <BitField5 as ::bit_field::BitFieldSpecifier>::GetterType {782let offset = <::bit_field::BitField0 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize783+ <BitField1 as 
::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize784+ <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize;785let val = self.get(offset, <BitField5 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH);786<BitField5 as ::bit_field::BitFieldSpecifier>::from_u64(val)787}788pub fn set_c(&mut self, val: <BitField5 as ::bit_field::BitFieldSpecifier>::SetterType) {789let val = <BitField5 as ::bit_field::BitFieldSpecifier>::into_u64(val);790debug_assert!(val <= ::bit_field::max::<BitField5>());791let offset = <::bit_field::BitField0 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize792+ <BitField1 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize793+ <BitField2 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize;794self.set(offset, <BitField5 as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH, val)795}796}797impl std::fmt::Debug for MyBitField {798fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {799f.debug_struct("MyBitField")800.field("a", &self.get_a())801.field("b", &self.get_b())802.field("c", &self.get_c())803.finish()804}805}806};807808assert_eq!(809bitfield_impl(&input).unwrap().to_string(),810expected.to_string()811);812}813}814815816