Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/python_dsl/python_udf.rs
8341 views
1
use std::io::Cursor;
2
use std::sync::{Arc, OnceLock};
3
4
use polars_core::datatypes::{DataType, Field};
5
use polars_core::error::*;
6
use polars_core::frame::DataFrame;
7
use polars_core::frame::column::Column;
8
use polars_core::schema::Schema;
9
use polars_utils::pl_str::PlSmallStr;
10
use pyo3::prelude::*;
11
12
use crate::dsl::udf::try_infer_udf_output_dtype;
13
use crate::prelude::*;
14
15
// Will be overwritten on Python Polars start up.
16
#[allow(clippy::type_complexity)]
17
pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
18
fn(s: &[Column], output_dtype: Option<DataType>, lambda: &Py<PyAny>) -> PolarsResult<Column>,
19
> = None;
20
21
#[allow(clippy::type_complexity)]
22
pub static mut CALL_DF_UDF_PYTHON: Option<
23
fn(s: DataFrame, lambda: &Py<PyAny>) -> PolarsResult<DataFrame>,
24
> = None;
25
26
pub use polars_utils::python_function::PythonFunction;
27
#[cfg(feature = "serde")]
28
pub use polars_utils::python_function::{PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON3_VERSION};
29
30
pub struct PythonUdfExpression {
31
python_function: Py<PyAny>,
32
output_type: Option<DataTypeExpr>,
33
materialized_field: OnceLock<Field>,
34
is_elementwise: bool,
35
returns_scalar: bool,
36
}
37
38
impl PythonUdfExpression {
39
pub fn new(
40
lambda: Py<PyAny>,
41
output_type: Option<impl Into<DataTypeExpr>>,
42
is_elementwise: bool,
43
returns_scalar: bool,
44
) -> Self {
45
let output_type = output_type.map(Into::into);
46
Self {
47
python_function: lambda,
48
output_type,
49
materialized_field: OnceLock::new(),
50
is_elementwise,
51
returns_scalar,
52
}
53
}
54
55
#[cfg(feature = "serde")]
56
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {
57
use polars_utils::pl_serialize;
58
59
if !buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK) {
60
polars_bail!(InvalidOperation: "serialization expected python magic byte mark");
61
}
62
let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
63
64
// Load UDF metadata
65
let mut reader = Cursor::new(buf);
66
let (output_type, materialized, is_elementwise, returns_scalar): (
67
Option<DataTypeExpr>,
68
Option<Field>,
69
bool,
70
bool,
71
) = pl_serialize::deserialize_from_reader::<_, _, true>(&mut reader)?;
72
73
let buf = &buf[reader.position() as usize..];
74
let python_function = pl_serialize::python_object_deserialize(buf)?;
75
76
let mut udf = Self::new(python_function, output_type, is_elementwise, returns_scalar);
77
if let Some(materialized) = materialized {
78
udf.materialized_field = OnceLock::from(materialized);
79
}
80
81
Ok(Arc::new(udf))
82
}
83
}
84
85
impl DataFrameUdf for polars_utils::python_function::PythonFunction {
86
fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
87
let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
88
func(df, &self.0)
89
}
90
91
fn display_str(&self) -> PlSmallStr {
92
pyo3::Python::attach(|py| {
93
use polars_utils::format_pl_smallstr;
94
use pyo3::intern;
95
use pyo3::pybacked::PyBackedStr;
96
97
let class_name: PyBackedStr = self
98
.0
99
.getattr(py, intern!(py, "__class__"))
100
.unwrap()
101
.extract(py)
102
.unwrap();
103
104
format_pl_smallstr!("PythonUdf({class_name})")
105
})
106
}
107
}
108
109
impl ColumnsUdf for PythonUdfExpression {
110
fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {
111
let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
112
let field = self
113
.materialized_field
114
.get()
115
.expect("should have been materialized at this point");
116
let mut out = func(
117
s,
118
self.materialized_field.get().map(|f| f.dtype.clone()),
119
&self.python_function,
120
)?;
121
122
let must_cast = out.dtype().matches_schema_type(field.dtype()).map_err(|_| {
123
polars_err!(
124
SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
125
field.dtype(), out.dtype(),
126
)
127
})?;
128
if must_cast {
129
out = out.cast(field.dtype())?;
130
}
131
132
Ok(out)
133
}
134
}
135
136
impl AnonymousColumnsUdf for PythonUdfExpression {
137
fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
138
self as _
139
}
140
fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
141
Arc::new(Self {
142
python_function: Python::attach(|py| self.python_function.clone_ref(py)),
143
output_type: self.output_type.clone(),
144
materialized_field: OnceLock::new(),
145
is_elementwise: self.is_elementwise,
146
returns_scalar: self.returns_scalar,
147
}) as _
148
}
149
150
#[cfg(feature = "serde")]
151
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
152
use polars_utils::pl_serialize;
153
154
// Write byte marks
155
buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
156
157
// Write UDF metadata
158
pl_serialize::serialize_into_writer::<_, _, true>(
159
&mut *buf,
160
&(
161
self.output_type.clone(),
162
self.materialized_field.get().cloned(),
163
self.is_elementwise,
164
self.returns_scalar,
165
),
166
)?;
167
168
pl_serialize::python_object_serialize(&self.python_function, buf)?;
169
Ok(())
170
}
171
172
fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {
173
let field = match self.materialized_field.get() {
174
Some(f) => f.clone(),
175
None => {
176
let dtype = match self.output_type.as_ref() {
177
None => {
178
let func = unsafe { CALL_COLUMNS_UDF_PYTHON.unwrap() };
179
let f = |s: &[Column]| func(s, None, &self.python_function);
180
try_infer_udf_output_dtype(&f as _, fields)?
181
},
182
Some(output_type) => output_type
183
.clone()
184
.into_datatype_with_self(input_schema, fields[0].dtype())?,
185
};
186
187
// Take the name of first field, just like `map_field`.
188
let name = fields[0].name();
189
let f = Field::new(name.clone(), dtype);
190
self.materialized_field.get_or_init(|| f.clone());
191
f
192
},
193
};
194
Ok(field)
195
}
196
}
197
198
impl Expr {
199
pub fn map_python(self, func: PythonUdfExpression) -> Expr {
200
Self::map_many_python(vec![self], func)
201
}
202
203
pub fn map_many_python(exprs: Vec<Expr>, func: PythonUdfExpression) -> Expr {
204
const NAME: &str = "python_udf";
205
206
let returns_scalar = func.returns_scalar;
207
208
let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
209
if func.is_elementwise {
210
flags.set_elementwise();
211
}
212
if returns_scalar {
213
flags |= FunctionFlags::RETURNS_SCALAR;
214
}
215
216
Expr::AnonymousFunction {
217
input: exprs,
218
function: new_column_udf(func),
219
options: FunctionOptions {
220
flags,
221
..Default::default()
222
},
223
fmt_str: Box::new(PlSmallStr::from(NAME)),
224
}
225
}
226
}
227
228