// Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/plugin.rs
// Page metadata: 7889 views
#![allow(unsafe_op_in_unsafe_fn)]1use std::ffi::CStr;2use std::sync::{LazyLock, RwLock};34use arrow::ffi::{ArrowSchema, import_field_from_c};5use libloading::Library;6#[cfg(feature = "python")]7use pyo3::{Python, types::PyAnyMethods};89use super::*;1011type PluginAndVersion = (Library, u16, u16);12static LOADED: LazyLock<RwLock<PlHashMap<String, PluginAndVersion>>> =13LazyLock::new(Default::default);1415fn get_lib(lib: &str) -> PolarsResult<&'static PluginAndVersion> {16let lib_map = LOADED.read().unwrap();17if let Some(library) = lib_map.get(lib) {18// lifetime is static as we never remove libraries.19Ok(unsafe { std::mem::transmute::<&PluginAndVersion, &'static PluginAndVersion>(library) })20} else {21drop(lib_map);2223#[cfg(feature = "python")]24let load_path = if !std::path::Path::new(lib).is_absolute() {25// Get python virtual environment path26let prefix = Python::attach(|py| {27let sys = py.import("sys").unwrap();28let prefix = sys.getattr("prefix").unwrap();29prefix.to_string()30});31let full_path = std::path::Path::new(&prefix).join(lib);32full_path.to_string_lossy().into_owned()33} else {34lib.to_string()35};36#[cfg(not(feature = "python"))]37let load_path = lib.to_string();3839let library = unsafe {40Library::new(&load_path).map_err(|e| {41PolarsError::ComputeError(format!("error loading dynamic library: {e}").into())42})?43};44let version_function: libloading::Symbol<unsafe extern "C" fn() -> u32> = unsafe {45library46.get("_polars_plugin_get_version".as_bytes())47.unwrap()48};4950let version = unsafe { version_function() };51let major = (version >> 16) as u16;52let minor = version as u16;5354let mut lib_map = LOADED.write().unwrap();55lib_map.insert(lib.to_string(), (library, major, minor));56drop(lib_map);5758get_lib(lib)59}60}6162fn retrieve_error_msg(lib: &Library) -> String {63unsafe {64// SAFETY: _polars_plugin_get_last_error_message returns data stored65// in a null-terminated thread-local so we immediately clone it.66let symbol: 
libloading::Symbol<unsafe extern "C" fn() -> *mut std::os::raw::c_char> =67lib.get(b"_polars_plugin_get_last_error_message\0").unwrap();68let msg_ptr = symbol();69CStr::from_ptr(msg_ptr).to_string_lossy().into_owned()70}71}7273#[doc(hidden)]74pub unsafe fn call_plugin(75s: &[Column],76lib: &str,77symbol: &str,78kwargs: &[u8],79) -> PolarsResult<Column> {80let plugin = get_lib(lib)?;81let lib = &plugin.0;82let major = plugin.1;8384if major == 0 {85use polars_ffi::version_0::*;86// *const SeriesExport: pointer to Box<SeriesExport>87// * usize: length of that pointer88// *const u8: pointer to &[u8]89// usize: length of the u8 slice90// *mut SeriesExport: pointer where return value should be written.91// *const CallerContext92let symbol: libloading::Symbol<93unsafe extern "C" fn(94*const SeriesExport,95usize,96*const u8,97usize,98*mut SeriesExport,99*const CallerContext,100),101> = lib102.get(format!("_polars_plugin_{symbol}").as_bytes())103.unwrap();104105// @scalar-correctness?106let input = s.iter().map(export_column).collect::<Vec<_>>();107let input_len = s.len();108let slice_ptr = input.as_ptr();109110let kwargs_ptr = kwargs.as_ptr();111let kwargs_len = kwargs.len();112113let mut return_value = SeriesExport::empty();114let return_value_ptr = &mut return_value as *mut SeriesExport;115let context = CallerContext::default();116let context_ptr = &context as *const CallerContext;117symbol(118slice_ptr,119input_len,120kwargs_ptr,121kwargs_len,122return_value_ptr,123context_ptr,124);125126// The inputs get dropped when the ffi side calls the drop callback.127for e in input {128std::mem::forget(e);129}130131if !return_value.is_null() {132import_series(return_value).map(Column::from)133} else {134let msg = retrieve_error_msg(lib);135check_panic(msg.as_ref())?;136polars_bail!(ComputeError: "the plugin failed with message: {}", msg)137}138} else {139polars_bail!(ComputeError: "this polars engine doesn't support plugin version: {}", major)140}141}142143pub(super) unsafe fn 
plugin_field(144fields: &[Field],145lib: &str,146symbol: &str,147kwargs: &[u8],148) -> PolarsResult<Field> {149let plugin = get_lib(lib)?;150let lib = &plugin.0;151let major = plugin.1;152let minor = plugin.2;153154// we deallocate the fields buffer155let ffi_fields = fields156.iter()157.map(|field| arrow::ffi::export_field_to_c(&field.to_arrow(CompatLevel::newest())))158.collect::<Vec<_>>()159.into_boxed_slice();160let n_args = ffi_fields.len();161let slice_ptr = ffi_fields.as_ptr();162163let mut return_value = ArrowSchema::empty();164let return_value_ptr = &mut return_value as *mut ArrowSchema;165166if major == 0 {167match minor {1680 => {169let views = fields.iter().any(|field| field.dtype.contains_views());170polars_ensure!(!views, ComputeError: "cannot call plugin\n\nThis Polars' version has a different 'binary/string' layout. Please compile with latest 'pyo3-polars'");171172// *const ArrowSchema: pointer to heap Box<ArrowSchema>173// usize: length of the boxed slice174// *mut ArrowSchema: pointer where the return value can be written175let symbol: libloading::Symbol<176unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema),177> = lib178.get((format!("_polars_plugin_field_{symbol}")).as_bytes())179.unwrap();180symbol(slice_ptr, n_args, return_value_ptr);181},1821 => {183// *const ArrowSchema: pointer to heap Box<ArrowSchema>184// usize: length of the boxed slice185// *mut ArrowSchema: pointer where the return value can be written186// *const u8: pointer to &[u8] (kwargs)187// usize: length of the u8 slice188let symbol: libloading::Symbol<189unsafe extern "C" fn(190*const ArrowSchema,191usize,192*mut ArrowSchema,193*const u8,194usize,195),196> = lib197.get((format!("_polars_plugin_field_{symbol}")).as_bytes())198.unwrap();199200let kwargs_ptr = kwargs.as_ptr();201let kwargs_len = kwargs.len();202203symbol(slice_ptr, n_args, return_value_ptr, kwargs_ptr, kwargs_len);204},205_ => {206polars_bail!(ComputeError: "this Polars engine doesn't support plugin 
version: {}-{}", major, minor)207},208}209if !return_value.is_null() {210let arrow_field = import_field_from_c(&return_value)?;211let out = Field::from(&arrow_field);212Ok(out)213} else {214let msg = retrieve_error_msg(lib);215check_panic(msg.as_ref())?;216polars_bail!(ComputeError: "the plugin failed with message: {}", msg)217}218} else {219polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}", major)220}221}222223fn check_panic(msg: &str) -> PolarsResult<()> {224polars_ensure!(msg != "PANIC", ComputeError: "the plugin panicked\n\nThe message is suppressed. Set POLARS_VERBOSE=1 to send the panic message to stderr.");225Ok(())226}227228229