Path: blob/main/crates/polars-plan/src/dsl/python_dsl/python_udf.rs
6940 views
use std::io::Cursor;1use std::sync::{Arc, OnceLock};23use polars_core::datatypes::{DataType, Field};4use polars_core::error::*;5use polars_core::frame::DataFrame;6use polars_core::frame::column::Column;7use polars_core::schema::Schema;8use polars_utils::pl_str::PlSmallStr;9use pyo3::prelude::*;1011use crate::dsl::udf::try_infer_udf_output_dtype;12use crate::prelude::*;1314// Will be overwritten on Python Polars start up.15#[allow(clippy::type_complexity)]16pub static mut CALL_COLUMNS_UDF_PYTHON: Option<17fn(s: &[Column], output_dtype: Option<DataType>, lambda: &PyObject) -> PolarsResult<Column>,18> = None;19pub static mut CALL_DF_UDF_PYTHON: Option<20fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,21> = None;2223pub use polars_utils::python_function::PythonFunction;24#[cfg(feature = "serde")]25pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};2627pub struct PythonUdfExpression {28python_function: PyObject,29output_type: Option<DataTypeExpr>,30materialized_field: OnceLock<Field>,31is_elementwise: bool,32returns_scalar: bool,33}3435impl PythonUdfExpression {36pub fn new(37lambda: PyObject,38output_type: Option<impl Into<DataTypeExpr>>,39is_elementwise: bool,40returns_scalar: bool,41) -> Self {42let output_type = output_type.map(Into::into);43Self {44python_function: lambda,45output_type,46materialized_field: OnceLock::new(),47is_elementwise,48returns_scalar,49}50}5152#[cfg(feature = "serde")]53pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {54use polars_utils::pl_serialize;5556if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {57polars_bail!(InvalidOperation: "serialization expected python magic byte mark");58}59let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];6061// Load UDF metadata62let mut reader = Cursor::new(buf);63let (output_type, materialized, is_elementwise, returns_scalar): (64Option<DataTypeExpr>,65Option<Field>,66bool,67bool,68) = pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;6970let buf = &buf[reader.position() as usize..];71let python_function = pl_serialize::python_object_deserialize(buf)?;7273let mut udf = Self::new(python_function, output_type, is_elementwise, returns_scalar);74if let Some(materialized) = materialized {75udf.materialized_field = OnceLock::from(materialized);76}7778Ok(Arc::new(udf))79}80}8182impl DataFrameUdf for polars_utils::python_function::PythonFunction {83fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {84let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };85func(df, &self.0)86}87}8889impl ColumnsUdf for PythonUdfExpression {90fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {91let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };92let field = self93.materialized_field94.get()95.expect("should have been materialized at this point");96let mut out = func(97s,98self.materialized_field.get().map(|f| f.dtype.clone()),99&self.python_function,100)?;101102let must_cast = out.dtype().matches_schema_type(field.dtype()).map_err(|_| {103polars_err!(104SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",105field.dtype(), out.dtype(),106)107})?;108if must_cast {109out = out.cast(field.dtype())?;110}111112Ok(out)113}114}115116impl AnonymousColumnsUdf for PythonUdfExpression {117fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {118self as _119}120fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {121Arc::new(Self {122python_function: Python::with_gil(|py| self.python_function.clone_ref(py)),123output_type: self.output_type.clone(),124materialized_field: OnceLock::new(),125is_elementwise: self.is_elementwise,126returns_scalar: self.returns_scalar,127}) as _128}129130#[cfg(feature = "serde")]131fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {132use polars_utils::pl_serialize;133134// Write byte marks135buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);136137// Write UDF metadata138pl_serialize::serialize_into_writer::<_, _, true>(139&mut *buf,140&(141self.output_type.clone(),142self.materialized_field.get().cloned(),143self.is_elementwise,144self.returns_scalar,145),146)?;147148pl_serialize::python_object_serialize(&self.python_function, buf)?;149Ok(())150}151152fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {153let field = match self.materialized_field.get() {154Some(f) => f.clone(),155None => {156let dtype = match self.output_type.as_ref() {157None => {158let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };159let f = |s: &[Column]| func(s, None, &self.python_function);160try_infer_udf_output_dtype(&f as _, fields)?161},162Some(output_type) => output_type163.clone()164.into_datatype_with_self(input_schema, fields[0].dtype())?,165};166167// Take the name of first field, just like `map_field`.168let name = fields[0].name();169let f = Field::new(name.clone(), dtype);170self.materialized_field.get_or_init(|| f.clone());171f172},173};174Ok(field)175}176}177178impl Expr {179pub fn map_python(self, func: PythonUdfExpression) -> Expr {180Self::map_many_python(vec![self], func)181}182183pub fn map_many_python(exprs: Vec<Expr>, func: PythonUdfExpression) -> Expr {184const NAME: &str = "python_udf";185186let returns_scalar = func.returns_scalar;187188let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;189if func.is_elementwise {190flags.set_elementwise();191}192if returns_scalar {193flags |= FunctionFlags::RETURNS_SCALAR;194}195196Expr::AnonymousFunction {197input: exprs,198function: new_column_udf(func),199options: FunctionOptions {200flags,201..Default::default()202},203fmt_str: Box::new(PlSmallStr::from(NAME)),204}205}206}207208209