Path: blob/main/crates/polars-plan/src/dsl/python_dsl/python_udf.rs
8341 views
use std::io::Cursor;1use std::sync::{Arc, OnceLock};23use polars_core::datatypes::{DataType, Field};4use polars_core::error::*;5use polars_core::frame::DataFrame;6use polars_core::frame::column::Column;7use polars_core::schema::Schema;8use polars_utils::pl_str::PlSmallStr;9use pyo3::prelude::*;1011use crate::dsl::udf::try_infer_udf_output_dtype;12use crate::prelude::*;1314// Will be overwritten on Python Polars start up.15#[allow(clippy::type_complexity)]16pub static mut CALL_COLUMNS_UDF_PYTHON: Option<17fn(s: &[Column], output_dtype: Option<DataType>, lambda: &Py<PyAny>) -> PolarsResult<Column>,18> = None;1920#[allow(clippy::type_complexity)]21pub static mut CALL_DF_UDF_PYTHON: Option<22fn(s: DataFrame, lambda: &Py<PyAny>) -> PolarsResult<DataFrame>,23> = None;2425pub use polars_utils::python_function::PythonFunction;26#[cfg(feature = "serde")]27pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};2829pub struct PythonUdfExpression {30python_function: Py<PyAny>,31output_type: Option<DataTypeExpr>,32materialized_field: OnceLock<Field>,33is_elementwise: bool,34returns_scalar: bool,35}3637impl PythonUdfExpression {38pub fn new(39lambda: Py<PyAny>,40output_type: Option<impl Into<DataTypeExpr>>,41is_elementwise: bool,42returns_scalar: bool,43) -> Self {44let output_type = output_type.map(Into::into);45Self {46python_function: lambda,47output_type,48materialized_field: OnceLock::new(),49is_elementwise,50returns_scalar,51}52}5354#[cfg(feature = "serde")]55pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {56use polars_utils::pl_serialize;5758if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {59polars_bail!(InvalidOperation: "serialization expected python magic byte mark");60}61let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];6263// Load UDF metadata64let mut reader = Cursor::new(buf);65let (output_type, materialized, is_elementwise, returns_scalar): (66Option<DataTypeExpr>,67Option<Field>,68bool,69bool,70) = pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;7172let buf = &buf[reader.position() as usize..];73let python_function = pl_serialize::python_object_deserialize(buf)?;7475let mut udf = Self::new(python_function, output_type, is_elementwise, returns_scalar);76if let Some(materialized) = materialized {77udf.materialized_field = OnceLock::from(materialized);78}7980Ok(Arc::new(udf))81}82}8384impl DataFrameUdf for polars_utils::python_function::PythonFunction {85fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {86let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };87func(df, &self.0)88}8990fn display_str(&self) -> PlSmallStr {91pyo3::Python::attach(|py| {92use polars_utils::format_pl_smallstr;93use pyo3::intern;94use pyo3::pybacked::PyBackedStr;9596let class_name: PyBackedStr = self97.098.getattr(py, intern!(py, "__class__"))99.unwrap()100.extract(py)101.unwrap();102103format_pl_smallstr!("PythonUdf({class_name})")104})105}106}107108impl ColumnsUdf for PythonUdfExpression {109fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {110let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };111let field = self112.materialized_field113.get()114.expect("should have been materialized at this point");115let mut out = func(116s,117self.materialized_field.get().map(|f| f.dtype.clone()),118&self.python_function,119)?;120121let must_cast = out.dtype().matches_schema_type(field.dtype()).map_err(|_| {122polars_err!(123SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",124field.dtype(), out.dtype(),125)126})?;127if must_cast {128out = out.cast(field.dtype())?;129}130131Ok(out)132}133}134135impl AnonymousColumnsUdf for PythonUdfExpression {136fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {137self as _138}139fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {140Arc::new(Self {141python_function: Python::attach(|py| self.python_function.clone_ref(py)),142output_type: self.output_type.clone(),143materialized_field: OnceLock::new(),144is_elementwise: self.is_elementwise,145returns_scalar: self.returns_scalar,146}) as _147}148149#[cfg(feature = "serde")]150fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {151use polars_utils::pl_serialize;152153// Write byte marks154buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);155156// Write UDF metadata157pl_serialize::serialize_into_writer::<_, _, true>(158&mut *buf,159&(160self.output_type.clone(),161self.materialized_field.get().cloned(),162self.is_elementwise,163self.returns_scalar,164),165)?;166167pl_serialize::python_object_serialize(&self.python_function, buf)?;168Ok(())169}170171fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {172let field = match self.materialized_field.get() {173Some(f) => f.clone(),174None => {175let dtype = match self.output_type.as_ref() {176None => {177let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };178let f = |s: &[Column]| func(s, None, &self.python_function);179try_infer_udf_output_dtype(&f as _, fields)?180},181Some(output_type) => output_type182.clone()183.into_datatype_with_self(input_schema, fields[0].dtype())?,184};185186// Take the name of first field, just like `map_field`.187let name = fields[0].name();188let f = Field::new(name.clone(), dtype);189self.materialized_field.get_or_init(|| f.clone());190f191},192};193Ok(field)194}195}196197impl Expr {198pub fn map_python(self, func: PythonUdfExpression) -> Expr {199Self::map_many_python(vec![self], func)200}201202pub fn map_many_python(exprs: Vec<Expr>, func: PythonUdfExpression) -> Expr {203const NAME: &str = "python_udf";204205let returns_scalar = func.returns_scalar;206207let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;208if func.is_elementwise {209flags.set_elementwise();210}211if returns_scalar {212flags |= FunctionFlags::RETURNS_SCALAR;213}214215Expr::AnonymousFunction {216input: exprs,217function: new_column_udf(func),218options: FunctionOptions {219flags,220..Default::default()221},222fmt_str: Box::new(PlSmallStr::from(NAME)),223}224}225}226227228