Path: blob/main/pyo3-polars/pyo3-polars-derive/src/lib.rs
6939 views
mod attr;1mod keywords;23use proc_macro::TokenStream;4use quote::quote;5use syn::{parse_macro_input, FnArg};67fn quote_get_kwargs() -> proc_macro2::TokenStream {8quote!(9let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);1011let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {12Ok(value) => value,13Err(err) => {14let err = polars_error::polars_err!(InvalidOperation: "could not parse kwargs: '{}'\n\nCheck: registration of kwargs in the plugin.", err);15pyo3_polars::derive::_update_last_error(err);16return;17}18};1920)21}2223fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {24let kwargs = quote_get_kwargs();25quote!(26// parse the kwargs and assign to `let kwargs`27#kwargs2829// define the function30#ast3132// call the function33let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, kwargs);3435)36}3738fn quote_call_context(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {39quote!(40let context = *context;4142// define the function43#ast4445// call the function46let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context);47)48}4950fn quote_call_context_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {51quote!(52let context = *context;5354let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);5556let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {57Ok(value) => value,58Err(err) => {59pyo3_polars::derive::_update_last_error(err);60return;61}62};6364// define the function65#ast6667// call the function68let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context, kwargs);69)70}7172fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {73quote!(74// define the function75#ast76// call the function77let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs);78)79}8081fn quote_process_results() -> proc_macro2::TokenStream {82quote!(match result {83Ok(out) => {84// Update return value.85*return_value = polars_ffi::version_0::export_series(&out);86},87Err(err) => {88// Set latest error, but leave return value in empty state.89pyo3_polars::derive::_update_last_error(err);90},91})92}9394fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {95// count how often the user define a kwargs argument.96let args = ast97.sig98.inputs99.iter()100.skip(1)101.map(|fn_arg| {102if let FnArg::Typed(pat) = fn_arg {103if let syn::Pat::Ident(pat) = pat.pat.as_ref() {104pat.ident.to_string()105} else {106panic!("expected an argument")107}108} else {109panic!("expected a type argument")110}111})112.collect::<Vec<_>>();113114let fn_name = &ast.sig.ident;115116// Get the tokenstream of the call logic.117let quote_call = match args.len() {1180 => quote_call_no_kwargs(&ast, fn_name),1191 => match args[0].as_str() {120"kwargs" => quote_call_kwargs(&ast, fn_name),121"context" => quote_call_context(&ast, fn_name),122a => panic!("didn't expect argument {a}"),123},1242 => match (args[0].as_str(), args[1].as_str()) {125("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name),126("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"),127(a, b) => panic!("didn't expect arguments {a}, {b}"),128},129_ => panic!("didn't expect so many arguments"),130};131132let quote_process_result = quote_process_results();133let fn_name = get_expression_function_name(fn_name);134135quote!(136use ::pyo3_polars::export::*;137138// create the outer public function139#[no_mangle]140pub unsafe extern "C" fn #fn_name (141e: *mut polars_ffi::version_0::SeriesExport,142input_len: usize,143kwargs_ptr: *const u8,144kwargs_len: usize,145return_value: *mut polars_ffi::version_0::SeriesExport,146context: *mut polars_ffi::version_0::CallerContext147) {148let panic_result = std::panic::catch_unwind(move || {149let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap();150151#quote_call152153#quote_process_result154});155156if panic_result.is_err() {157// Set latest to panic;158::pyo3_polars::derive::_set_panic();159}160}161)162}163164fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident {165syn::Ident::new(&format!("_polars_plugin_field_{fn_name}"), fn_name.span())166}167168fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {169syn::Ident::new(&format!("_polars_plugin_{fn_name}"), fn_name.span())170}171172fn quote_get_inputs() -> proc_macro2::TokenStream {173quote!(174let inputs = std::slice::from_raw_parts(field, len);175let inputs = inputs.iter().map(|field| {176let field = polars_arrow::ffi::import_field_from_c(field).unwrap();177let out = polars_core::prelude::Field::from(&field);178out179}).collect::<Vec<_>>();180)181}182183fn create_field_function(184fn_name: &syn::Ident,185dtype_fn_name: &syn::Ident,186kwargs: bool,187) -> proc_macro2::TokenStream {188let map_field_name = get_field_function_name(fn_name);189let inputs = quote_get_inputs();190191let call_fn = if kwargs {192let kwargs = quote_get_kwargs();193quote! (194#kwargs195let result = #dtype_fn_name(&inputs, kwargs);196)197} else {198quote!(199let result = #dtype_fn_name(&inputs);200)201};202203quote! (204#[no_mangle]205pub unsafe extern "C" fn #map_field_name(206field: *mut polars_arrow::ffi::ArrowSchema,207len: usize,208return_value: *mut polars_arrow::ffi::ArrowSchema,209kwargs_ptr: *const u8,210kwargs_len: usize,211) {212let panic_result = std::panic::catch_unwind(move || {213#inputs;214215#call_fn;216217match result {218Ok(out) => {219let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));220*return_value = out;221},222Err(err) => {223// Set latest error, but leave return value in empty state.224pyo3_polars::derive::_update_last_error(err);225}226}227});228229if panic_result.is_err() {230// Set latest to panic;231pyo3_polars::derive::_set_panic();232}233}234)235}236237fn create_field_function_from_with_dtype(238fn_name: &syn::Ident,239dtype: syn::Ident,240) -> proc_macro2::TokenStream {241let map_field_name = get_field_function_name(fn_name);242let inputs = quote_get_inputs();243244quote! (245#[no_mangle]246pub unsafe extern "C" fn #map_field_name(247field: *mut polars_arrow::ffi::ArrowSchema,248len: usize,249return_value: *mut polars_arrow::ffi::ArrowSchema250) {251#inputs252253let mapper = polars_plan::prelude::FieldsMapper::new(&inputs);254let dtype = polars_core::datatypes::DataType::#dtype;255let out = mapper.with_dtype(dtype).unwrap();256let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));257*return_value = out;258}259)260}261262#[proc_macro_attribute]263pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream {264let ast = parse_macro_input!(input as syn::ItemFn);265266let options = parse_macro_input!(attr as attr::ExprsFunctionOptions);267let expanded_field_fn = if let Some(fn_name) = options.output_type_fn {268create_field_function(&ast.sig.ident, &fn_name, false)269} else if let Some(fn_name) = options.output_type_fn_kwargs {270create_field_function(&ast.sig.ident, &fn_name, true)271} else if let Some(dtype) = options.output_dtype {272create_field_function_from_with_dtype(&ast.sig.ident, dtype)273} else {274panic!("didn't understand polars_expr attribute")275};276277let expanded_expr = create_expression_function(ast);278let expanded = quote!(279#expanded_field_fn280281#expanded_expr282);283TokenStream::from(expanded)284}285286287