Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/plugin.rs
7889 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::ffi::CStr;
3
use std::sync::{LazyLock, RwLock};
4
5
use arrow::ffi::{ArrowSchema, import_field_from_c};
6
use libloading::Library;
7
#[cfg(feature = "python")]
8
use pyo3::{Python, types::PyAnyMethods};
9
10
use super::*;
11
12
type PluginAndVersion = (Library, u16, u16);
13
static LOADED: LazyLock<RwLock<PlHashMap<String, PluginAndVersion>>> =
14
LazyLock::new(Default::default);
15
16
fn get_lib(lib: &str) -> PolarsResult<&'static PluginAndVersion> {
17
let lib_map = LOADED.read().unwrap();
18
if let Some(library) = lib_map.get(lib) {
19
// lifetime is static as we never remove libraries.
20
Ok(unsafe { std::mem::transmute::<&PluginAndVersion, &'static PluginAndVersion>(library) })
21
} else {
22
drop(lib_map);
23
24
#[cfg(feature = "python")]
25
let load_path = if !std::path::Path::new(lib).is_absolute() {
26
// Get python virtual environment path
27
let prefix = Python::attach(|py| {
28
let sys = py.import("sys").unwrap();
29
let prefix = sys.getattr("prefix").unwrap();
30
prefix.to_string()
31
});
32
let full_path = std::path::Path::new(&prefix).join(lib);
33
full_path.to_string_lossy().into_owned()
34
} else {
35
lib.to_string()
36
};
37
#[cfg(not(feature = "python"))]
38
let load_path = lib.to_string();
39
40
let library = unsafe {
41
Library::new(&load_path).map_err(|e| {
42
PolarsError::ComputeError(format!("error loading dynamic library: {e}").into())
43
})?
44
};
45
let version_function: libloading::Symbol<unsafe extern "C" fn() -> u32> = unsafe {
46
library
47
.get("_polars_plugin_get_version".as_bytes())
48
.unwrap()
49
};
50
51
let version = unsafe { version_function() };
52
let major = (version >> 16) as u16;
53
let minor = version as u16;
54
55
let mut lib_map = LOADED.write().unwrap();
56
lib_map.insert(lib.to_string(), (library, major, minor));
57
drop(lib_map);
58
59
get_lib(lib)
60
}
61
}
62
63
fn retrieve_error_msg(lib: &Library) -> String {
64
unsafe {
65
// SAFETY: _polars_plugin_get_last_error_message returns data stored
66
// in a null-terminated thread-local so we immediately clone it.
67
let symbol: libloading::Symbol<unsafe extern "C" fn() -> *mut std::os::raw::c_char> =
68
lib.get(b"_polars_plugin_get_last_error_message\0").unwrap();
69
let msg_ptr = symbol();
70
CStr::from_ptr(msg_ptr).to_string_lossy().into_owned()
71
}
72
}
73
74
#[doc(hidden)]
75
pub unsafe fn call_plugin(
76
s: &[Column],
77
lib: &str,
78
symbol: &str,
79
kwargs: &[u8],
80
) -> PolarsResult<Column> {
81
let plugin = get_lib(lib)?;
82
let lib = &plugin.0;
83
let major = plugin.1;
84
85
if major == 0 {
86
use polars_ffi::version_0::*;
87
// *const SeriesExport: pointer to Box<SeriesExport>
88
// * usize: length of that pointer
89
// *const u8: pointer to &[u8]
90
// usize: length of the u8 slice
91
// *mut SeriesExport: pointer where return value should be written.
92
// *const CallerContext
93
let symbol: libloading::Symbol<
94
unsafe extern "C" fn(
95
*const SeriesExport,
96
usize,
97
*const u8,
98
usize,
99
*mut SeriesExport,
100
*const CallerContext,
101
),
102
> = lib
103
.get(format!("_polars_plugin_{symbol}").as_bytes())
104
.unwrap();
105
106
// @scalar-correctness?
107
let input = s.iter().map(export_column).collect::<Vec<_>>();
108
let input_len = s.len();
109
let slice_ptr = input.as_ptr();
110
111
let kwargs_ptr = kwargs.as_ptr();
112
let kwargs_len = kwargs.len();
113
114
let mut return_value = SeriesExport::empty();
115
let return_value_ptr = &mut return_value as *mut SeriesExport;
116
let context = CallerContext::default();
117
let context_ptr = &context as *const CallerContext;
118
symbol(
119
slice_ptr,
120
input_len,
121
kwargs_ptr,
122
kwargs_len,
123
return_value_ptr,
124
context_ptr,
125
);
126
127
// The inputs get dropped when the ffi side calls the drop callback.
128
for e in input {
129
std::mem::forget(e);
130
}
131
132
if !return_value.is_null() {
133
import_series(return_value).map(Column::from)
134
} else {
135
let msg = retrieve_error_msg(lib);
136
check_panic(msg.as_ref())?;
137
polars_bail!(ComputeError: "the plugin failed with message: {}", msg)
138
}
139
} else {
140
polars_bail!(ComputeError: "this polars engine doesn't support plugin version: {}", major)
141
}
142
}
143
144
pub(super) unsafe fn plugin_field(
145
fields: &[Field],
146
lib: &str,
147
symbol: &str,
148
kwargs: &[u8],
149
) -> PolarsResult<Field> {
150
let plugin = get_lib(lib)?;
151
let lib = &plugin.0;
152
let major = plugin.1;
153
let minor = plugin.2;
154
155
// we deallocate the fields buffer
156
let ffi_fields = fields
157
.iter()
158
.map(|field| arrow::ffi::export_field_to_c(&field.to_arrow(CompatLevel::newest())))
159
.collect::<Vec<_>>()
160
.into_boxed_slice();
161
let n_args = ffi_fields.len();
162
let slice_ptr = ffi_fields.as_ptr();
163
164
let mut return_value = ArrowSchema::empty();
165
let return_value_ptr = &mut return_value as *mut ArrowSchema;
166
167
if major == 0 {
168
match minor {
169
0 => {
170
let views = fields.iter().any(|field| field.dtype.contains_views());
171
polars_ensure!(!views, ComputeError: "cannot call plugin\n\nThis Polars' version has a different 'binary/string' layout. Please compile with latest 'pyo3-polars'");
172
173
// *const ArrowSchema: pointer to heap Box<ArrowSchema>
174
// usize: length of the boxed slice
175
// *mut ArrowSchema: pointer where the return value can be written
176
let symbol: libloading::Symbol<
177
unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema),
178
> = lib
179
.get((format!("_polars_plugin_field_{symbol}")).as_bytes())
180
.unwrap();
181
symbol(slice_ptr, n_args, return_value_ptr);
182
},
183
1 => {
184
// *const ArrowSchema: pointer to heap Box<ArrowSchema>
185
// usize: length of the boxed slice
186
// *mut ArrowSchema: pointer where the return value can be written
187
// *const u8: pointer to &[u8] (kwargs)
188
// usize: length of the u8 slice
189
let symbol: libloading::Symbol<
190
unsafe extern "C" fn(
191
*const ArrowSchema,
192
usize,
193
*mut ArrowSchema,
194
*const u8,
195
usize,
196
),
197
> = lib
198
.get((format!("_polars_plugin_field_{symbol}")).as_bytes())
199
.unwrap();
200
201
let kwargs_ptr = kwargs.as_ptr();
202
let kwargs_len = kwargs.len();
203
204
symbol(slice_ptr, n_args, return_value_ptr, kwargs_ptr, kwargs_len);
205
},
206
_ => {
207
polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}-{}", major, minor)
208
},
209
}
210
if !return_value.is_null() {
211
let arrow_field = import_field_from_c(&return_value)?;
212
let out = Field::from(&arrow_field);
213
Ok(out)
214
} else {
215
let msg = retrieve_error_msg(lib);
216
check_panic(msg.as_ref())?;
217
polars_bail!(ComputeError: "the plugin failed with message: {}", msg)
218
}
219
} else {
220
polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}", major)
221
}
222
}
223
224
fn check_panic(msg: &str) -> PolarsResult<()> {
225
polars_ensure!(msg != "PANIC", ComputeError: "the plugin panicked\n\nThe message is suppressed. Set POLARS_VERBOSE=1 to send the panic message to stderr.");
226
Ok(())
227
}
228
229