Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/pyo3-polars/pyo3-polars-derive/src/lib.rs
6939 views
1
mod attr;
2
mod keywords;
3
4
use proc_macro::TokenStream;
5
use quote::quote;
6
use syn::{parse_macro_input, FnArg};
7
8
fn quote_get_kwargs() -> proc_macro2::TokenStream {
9
quote!(
10
let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
11
12
let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
13
Ok(value) => value,
14
Err(err) => {
15
let err = polars_error::polars_err!(InvalidOperation: "could not parse kwargs: '{}'\n\nCheck: registration of kwargs in the plugin.", err);
16
pyo3_polars::derive::_update_last_error(err);
17
return;
18
}
19
};
20
21
)
22
}
23
24
fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
25
let kwargs = quote_get_kwargs();
26
quote!(
27
// parse the kwargs and assign to `let kwargs`
28
#kwargs
29
30
// define the function
31
#ast
32
33
// call the function
34
let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, kwargs);
35
36
)
37
}
38
39
fn quote_call_context(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
40
quote!(
41
let context = *context;
42
43
// define the function
44
#ast
45
46
// call the function
47
let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context);
48
)
49
}
50
51
fn quote_call_context_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
52
quote!(
53
let context = *context;
54
55
let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
56
57
let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
58
Ok(value) => value,
59
Err(err) => {
60
pyo3_polars::derive::_update_last_error(err);
61
return;
62
}
63
};
64
65
// define the function
66
#ast
67
68
// call the function
69
let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context, kwargs);
70
)
71
}
72
73
fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
74
quote!(
75
// define the function
76
#ast
77
// call the function
78
let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs);
79
)
80
}
81
82
fn quote_process_results() -> proc_macro2::TokenStream {
83
quote!(match result {
84
Ok(out) => {
85
// Update return value.
86
*return_value = polars_ffi::version_0::export_series(&out);
87
},
88
Err(err) => {
89
// Set latest error, but leave return value in empty state.
90
pyo3_polars::derive::_update_last_error(err);
91
},
92
})
93
}
94
95
fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
96
// count how often the user define a kwargs argument.
97
let args = ast
98
.sig
99
.inputs
100
.iter()
101
.skip(1)
102
.map(|fn_arg| {
103
if let FnArg::Typed(pat) = fn_arg {
104
if let syn::Pat::Ident(pat) = pat.pat.as_ref() {
105
pat.ident.to_string()
106
} else {
107
panic!("expected an argument")
108
}
109
} else {
110
panic!("expected a type argument")
111
}
112
})
113
.collect::<Vec<_>>();
114
115
let fn_name = &ast.sig.ident;
116
117
// Get the tokenstream of the call logic.
118
let quote_call = match args.len() {
119
0 => quote_call_no_kwargs(&ast, fn_name),
120
1 => match args[0].as_str() {
121
"kwargs" => quote_call_kwargs(&ast, fn_name),
122
"context" => quote_call_context(&ast, fn_name),
123
a => panic!("didn't expect argument {a}"),
124
},
125
2 => match (args[0].as_str(), args[1].as_str()) {
126
("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name),
127
("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"),
128
(a, b) => panic!("didn't expect arguments {a}, {b}"),
129
},
130
_ => panic!("didn't expect so many arguments"),
131
};
132
133
let quote_process_result = quote_process_results();
134
let fn_name = get_expression_function_name(fn_name);
135
136
quote!(
137
use ::pyo3_polars::export::*;
138
139
// create the outer public function
140
#[no_mangle]
141
pub unsafe extern "C" fn #fn_name (
142
e: *mut polars_ffi::version_0::SeriesExport,
143
input_len: usize,
144
kwargs_ptr: *const u8,
145
kwargs_len: usize,
146
return_value: *mut polars_ffi::version_0::SeriesExport,
147
context: *mut polars_ffi::version_0::CallerContext
148
) {
149
let panic_result = std::panic::catch_unwind(move || {
150
let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap();
151
152
#quote_call
153
154
#quote_process_result
155
});
156
157
if panic_result.is_err() {
158
// Set latest to panic;
159
::pyo3_polars::derive::_set_panic();
160
}
161
}
162
)
163
}
164
165
/// Builds the exported symbol name of the output-field function for `fn_name`:
/// `_polars_plugin_field_<fn_name>`, spanned at `fn_name`.
fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident {
    let symbol = format!("_polars_plugin_field_{}", fn_name);
    syn::Ident::new(&symbol, fn_name.span())
}
168
169
/// Builds the exported symbol name of the expression function for `fn_name`:
/// `_polars_plugin_<fn_name>`, spanned at `fn_name`.
fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {
    let symbol = format!("_polars_plugin_{}", fn_name);
    syn::Ident::new(&symbol, fn_name.span())
}
172
173
fn quote_get_inputs() -> proc_macro2::TokenStream {
174
quote!(
175
let inputs = std::slice::from_raw_parts(field, len);
176
let inputs = inputs.iter().map(|field| {
177
let field = polars_arrow::ffi::import_field_from_c(field).unwrap();
178
let out = polars_core::prelude::Field::from(&field);
179
out
180
}).collect::<Vec<_>>();
181
)
182
}
183
184
fn create_field_function(
185
fn_name: &syn::Ident,
186
dtype_fn_name: &syn::Ident,
187
kwargs: bool,
188
) -> proc_macro2::TokenStream {
189
let map_field_name = get_field_function_name(fn_name);
190
let inputs = quote_get_inputs();
191
192
let call_fn = if kwargs {
193
let kwargs = quote_get_kwargs();
194
quote! (
195
#kwargs
196
let result = #dtype_fn_name(&inputs, kwargs);
197
)
198
} else {
199
quote!(
200
let result = #dtype_fn_name(&inputs);
201
)
202
};
203
204
quote! (
205
#[no_mangle]
206
pub unsafe extern "C" fn #map_field_name(
207
field: *mut polars_arrow::ffi::ArrowSchema,
208
len: usize,
209
return_value: *mut polars_arrow::ffi::ArrowSchema,
210
kwargs_ptr: *const u8,
211
kwargs_len: usize,
212
) {
213
let panic_result = std::panic::catch_unwind(move || {
214
#inputs;
215
216
#call_fn;
217
218
match result {
219
Ok(out) => {
220
let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));
221
*return_value = out;
222
},
223
Err(err) => {
224
// Set latest error, but leave return value in empty state.
225
pyo3_polars::derive::_update_last_error(err);
226
}
227
}
228
});
229
230
if panic_result.is_err() {
231
// Set latest to panic;
232
pyo3_polars::derive::_set_panic();
233
}
234
}
235
)
236
}
237
238
fn create_field_function_from_with_dtype(
239
fn_name: &syn::Ident,
240
dtype: syn::Ident,
241
) -> proc_macro2::TokenStream {
242
let map_field_name = get_field_function_name(fn_name);
243
let inputs = quote_get_inputs();
244
245
quote! (
246
#[no_mangle]
247
pub unsafe extern "C" fn #map_field_name(
248
field: *mut polars_arrow::ffi::ArrowSchema,
249
len: usize,
250
return_value: *mut polars_arrow::ffi::ArrowSchema
251
) {
252
#inputs
253
254
let mapper = polars_plan::prelude::FieldsMapper::new(&inputs);
255
let dtype = polars_core::datatypes::DataType::#dtype;
256
let out = mapper.with_dtype(dtype).unwrap();
257
let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));
258
*return_value = out;
259
}
260
)
261
}
262
263
#[proc_macro_attribute]
264
pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream {
265
let ast = parse_macro_input!(input as syn::ItemFn);
266
267
let options = parse_macro_input!(attr as attr::ExprsFunctionOptions);
268
let expanded_field_fn = if let Some(fn_name) = options.output_type_fn {
269
create_field_function(&ast.sig.ident, &fn_name, false)
270
} else if let Some(fn_name) = options.output_type_fn_kwargs {
271
create_field_function(&ast.sig.ident, &fn_name, true)
272
} else if let Some(dtype) = options.output_dtype {
273
create_field_function_from_with_dtype(&ast.sig.ident, dtype)
274
} else {
275
panic!("didn't understand polars_expr attribute")
276
};
277
278
let expanded_expr = create_expression_function(ast);
279
let expanded = quote!(
280
#expanded_field_fn
281
282
#expanded_expr
283
);
284
TokenStream::from(expanded)
285
}
286
287