Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/python_dsl/python_udf.rs
6940 views
1
use std::io::Cursor;
2
use std::sync::{Arc, OnceLock};
3
4
use polars_core::datatypes::{DataType, Field};
5
use polars_core::error::*;
6
use polars_core::frame::DataFrame;
7
use polars_core::frame::column::Column;
8
use polars_core::schema::Schema;
9
use polars_utils::pl_str::PlSmallStr;
10
use pyo3::prelude::*;
11
12
use crate::dsl::udf::try_infer_udf_output_dtype;
13
use crate::prelude::*;
14
15
// Will be overwritten on Python Polars start up.
16
#[allow(clippy::type_complexity)]
17
pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
18
fn(s: &[Column], output_dtype: Option<DataType>, lambda: &PyObject) -> PolarsResult<Column>,
19
> = None;
20
pub static mut CALL_DF_UDF_PYTHON: Option<
21
fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
22
> = None;
23
24
pub use polars_utils::python_function::PythonFunction;
25
#[cfg(feature = "serde")]
26
pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};
27
28
pub struct PythonUdfExpression {
29
python_function: PyObject,
30
output_type: Option<DataTypeExpr>,
31
materialized_field: OnceLock<Field>,
32
is_elementwise: bool,
33
returns_scalar: bool,
34
}
35
36
impl PythonUdfExpression {
37
pub fn new(
38
lambda: PyObject,
39
output_type: Option<impl Into<DataTypeExpr>>,
40
is_elementwise: bool,
41
returns_scalar: bool,
42
) -> Self {
43
let output_type = output_type.map(Into::into);
44
Self {
45
python_function: lambda,
46
output_type,
47
materialized_field: OnceLock::new(),
48
is_elementwise,
49
returns_scalar,
50
}
51
}
52
53
#[cfg(feature = "serde")]
54
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {
55
use polars_utils::pl_serialize;
56
57
if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {
58
polars_bail!(InvalidOperation: "serialization expected python magic byte mark");
59
}
60
let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
61
62
// Load UDF metadata
63
let mut reader = Cursor::new(buf);
64
let (output_type, materialized, is_elementwise, returns_scalar): (
65
Option<DataTypeExpr>,
66
Option<Field>,
67
bool,
68
bool,
69
) = pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
70
71
let buf = &buf[reader.position() as usize..];
72
let python_function = pl_serialize::python_object_deserialize(buf)?;
73
74
let mut udf = Self::new(python_function, output_type, is_elementwise, returns_scalar);
75
if let Some(materialized) = materialized {
76
udf.materialized_field = OnceLock::from(materialized);
77
}
78
79
Ok(Arc::new(udf))
80
}
81
}
82
83
impl DataFrameUdf for polars_utils::python_function::PythonFunction {
84
fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
85
let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
86
func(df, &self.0)
87
}
88
}
89
90
impl ColumnsUdf for PythonUdfExpression {
91
fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {
92
let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
93
let field = self
94
.materialized_field
95
.get()
96
.expect("should have been materialized at this point");
97
let mut out = func(
98
s,
99
self.materialized_field.get().map(|f| f.dtype.clone()),
100
&self.python_function,
101
)?;
102
103
let must_cast = out.dtype().matches_schema_type(field.dtype()).map_err(|_| {
104
polars_err!(
105
SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
106
field.dtype(), out.dtype(),
107
)
108
})?;
109
if must_cast {
110
out = out.cast(field.dtype())?;
111
}
112
113
Ok(out)
114
}
115
}
116
117
impl AnonymousColumnsUdf for PythonUdfExpression {
118
fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
119
self as _
120
}
121
fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
122
Arc::new(Self {
123
python_function: Python::with_gil(|py| self.python_function.clone_ref(py)),
124
output_type: self.output_type.clone(),
125
materialized_field: OnceLock::new(),
126
is_elementwise: self.is_elementwise,
127
returns_scalar: self.returns_scalar,
128
}) as _
129
}
130
131
#[cfg(feature = "serde")]
132
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
133
use polars_utils::pl_serialize;
134
135
// Write byte marks
136
buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
137
138
// Write UDF metadata
139
pl_serialize::serialize_into_writer::<_, _, true>(
140
&mut *buf,
141
&(
142
self.output_type.clone(),
143
self.materialized_field.get().cloned(),
144
self.is_elementwise,
145
self.returns_scalar,
146
),
147
)?;
148
149
pl_serialize::python_object_serialize(&self.python_function, buf)?;
150
Ok(())
151
}
152
153
fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {
154
let field = match self.materialized_field.get() {
155
Some(f) => f.clone(),
156
None => {
157
let dtype = match self.output_type.as_ref() {
158
None => {
159
let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
160
let f = |s: &[Column]| func(s, None, &self.python_function);
161
try_infer_udf_output_dtype(&f as _, fields)?
162
},
163
Some(output_type) => output_type
164
.clone()
165
.into_datatype_with_self(input_schema, fields[0].dtype())?,
166
};
167
168
// Take the name of first field, just like `map_field`.
169
let name = fields[0].name();
170
let f = Field::new(name.clone(), dtype);
171
self.materialized_field.get_or_init(|| f.clone());
172
f
173
},
174
};
175
Ok(field)
176
}
177
}
178
179
impl Expr {
180
pub fn map_python(self, func: PythonUdfExpression) -> Expr {
181
Self::map_many_python(vec![self], func)
182
}
183
184
pub fn map_many_python(exprs: Vec<Expr>, func: PythonUdfExpression) -> Expr {
185
const NAME: &str = "python_udf";
186
187
let returns_scalar = func.returns_scalar;
188
189
let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
190
if func.is_elementwise {
191
flags.set_elementwise();
192
}
193
if returns_scalar {
194
flags |= FunctionFlags::RETURNS_SCALAR;
195
}
196
197
Expr::AnonymousFunction {
198
input: exprs,
199
function: new_column_udf(func),
200
options: FunctionOptions {
201
flags,
202
..Default::default()
203
},
204
fmt_str: Box::new(PlSmallStr::from(NAME)),
205
}
206
}
207
}
208
209