Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/callback.rs
8409 views
1
use std::fmt;
2
use std::sync::Arc;
3
4
use polars_error::PolarsResult;
5
#[cfg(feature = "python")]
6
use polars_utils::python_function::PythonObject;
7
8
use crate::dsl::SpecialEq;
9
10
#[derive(strum_macros::IntoStaticStr)]
11
pub enum PlanCallback<Args, Out> {
12
#[cfg(feature = "python")]
13
Python(SpecialEq<Arc<polars_utils::python_function::PythonFunction>>),
14
Rust(SpecialEq<Arc<dyn Fn(Args) -> PolarsResult<Out> + Send + Sync>>),
15
}
16
17
impl<Args, Out> PartialEq for PlanCallback<Args, Out> {
18
fn eq(&self, other: &Self) -> bool {
19
use PlanCallback as C;
20
21
match (self, other) {
22
#[cfg(feature = "python")]
23
(C::Python(l), C::Python(r)) => SpecialEq::eq(l, r) || PythonObject::eq(l, r),
24
(C::Rust(l), C::Rust(r)) => l.eq(r),
25
_ => false,
26
}
27
}
28
}
29
30
impl<Args, Out> Eq for PlanCallback<Args, Out> {}
31
32
impl<Args, Out> fmt::Debug for PlanCallback<Args, Out> {
33
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
34
f.write_str("PlanCallback::")?;
35
f.write_str(<&'static str>::from(self))
36
}
37
}
38
39
#[cfg(feature = "serde")]
impl<Args, Out> serde::Serialize for PlanCallback<Args, Out> {
    /// Serializes a Python callback by delegating to the wrapped value.
    ///
    /// # Errors
    /// A Rust closure has no serialized form and yields a custom serializer error
    /// that includes the callback's `Debug` output.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::Error;

        match self {
            #[cfg(feature = "python")]
            Self::Python(pyfn) => pyfn.serialize(_serializer),
            Self::Rust(_) => Err(S::Error::custom(format!(
                "cannot serialize 'opaque' function in {self:?}"
            ))),
        }
    }
}
57
58
#[cfg(feature = "serde")]
impl<'de, Args, Out> serde::Deserialize<'de> for PlanCallback<Args, Out> {
    /// Deserializes a callback.
    ///
    /// With the `python` feature this always produces the `Python` variant. Without it,
    /// deserialization always fails, because Rust closures have no serialized form.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;
            Err(D::Error::custom("cannot deserialize PlanCallback"))
        }
        #[cfg(feature = "python")]
        {
            let pyfn = polars_utils::python_function::PythonFunction::deserialize(_deserializer)?;
            Ok(Self::Python(SpecialEq::new(Arc::new(pyfn))))
        }
    }
}
77
78
#[cfg(feature = "dsl-schema")]
impl<Args, Out> schemars::JsonSchema for PlanCallback<Args, Out> {
    /// Shared by every instantiation, regardless of `Args` and `Out`.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("PlanCallback")
    }

    fn schema_id() -> std::borrow::Cow<'static, str> {
        concat!(module_path!(), "::PlanCallback").into()
    }

    /// Described with the same schema as `Vec<u8>`.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
92
93
impl<Args, Out> std::hash::Hash for PlanCallback<Args, Out> {
94
fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {
95
// no-op.
96
}
97
}
98
99
impl<Args, Out> Clone for PlanCallback<Args, Out> {
100
fn clone(&self) -> Self {
101
match self {
102
#[cfg(feature = "python")]
103
Self::Python(p) => Self::Python(p.clone()),
104
Self::Rust(f) => Self::Rust(f.clone()),
105
}
106
}
107
}
108
109
/// Types that can be passed as arguments to a [`PlanCallback`].
///
/// With the `python` feature, implementors convert themselves into a Python object
/// for Python callbacks. Without it, the trait has no methods.
pub trait PlanCallbackArgs {
    /// Converts `self` into an owned Python object; `py` is the pyo3 `Python` token.
    #[cfg(feature = "python")]
    fn into_pyany<'py>(self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>>;
}

/// Types that can be returned from a [`PlanCallback`].
///
/// With the `python` feature, implementors are extracted from the object a Python
/// callback returns. Without it, the trait has no methods.
pub trait PlanCallbackOut: Sized {
    /// Extracts `Self` from the Python object `pyany`; `py` is the pyo3 `Python` token.
    #[cfg(feature = "python")]
    fn from_pyany<'py>(
        pyany: pyo3::Py<pyo3::PyAny>,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<Self>;
}
118
119
/// Python conversions for callback arguments and return values.
///
/// Plain types (`bool`, `usize`, `String`, `PlSmallStr`) convert through pyo3 directly.
/// `Series`, `DataFrame`, `DslPlan` and `Schema` convert through the python convert registry.
#[cfg(feature = "python")]
mod _python {
    use std::sync::Arc;

    use polars_utils::pl_str::PlSmallStr;
    use pyo3::types::{PyAnyMethods, PyList, PyTuple};
    use pyo3::*;

    // Implements both traits for types pyo3 converts directly
    // (`into_pyobject` / `extract`).
    macro_rules! impl_pycb_type {
        ($($type:ty),+) => {
            $(
                impl super::PlanCallbackArgs for $type {
                    fn into_pyany<'py>(self, py: Python<'py>) -> PyResult<Py<PyAny>> {
                        Ok(self.into_pyobject(py)?.into_any().unbind())
                    }
                }

                impl super::PlanCallbackOut for $type {
                    fn from_pyany<'py>(pyany: Py<PyAny>, py: Python<'py>) -> PyResult<Self> {
                        pyany.bind(py).extract::<Self>()
                    }
                }
            )+
        };
    }

    // Implements both traits for `$type` by converting through `$transformed`,
    // which pyo3 understands (`$type: From<$transformed>`, `$transformed: From<$type>`).
    macro_rules! impl_pycb_type_to_from {
        ($($type:ty => $transformed:ty),+) => {
            $(
                impl super::PlanCallbackArgs for $type {
                    fn into_pyany<'py>(self, py: Python<'py>) -> PyResult<Py<PyAny>> {
                        Ok(<$transformed>::from(self).into_pyobject(py)?.into_any().unbind())
                    }
                }

                impl super::PlanCallbackOut for $type {
                    fn from_pyany<'py>(pyany: Py<PyAny>, py: Python<'py>) -> PyResult<Self> {
                        pyany.bind(py).extract::<$transformed>().map(Into::into)
                    }
                }
            )+
        };
    }

    // Implements both traits by delegating to the python convert registry;
    // `$to` / `$from` name the registry's `to_py` / `from_py` entries for `$type`.
    macro_rules! impl_registrycb_type {
        ($(($type:path, $from:ident, $to:ident)),+) => {
            $(
                impl super::PlanCallbackArgs for $type {
                    fn into_pyany<'py>(self, _py: Python<'py>) -> PyResult<Py<PyAny>> {
                        let registry = polars_utils::python_convert_registry::get_python_convert_registry();
                        (registry.to_py.$to)(&self)
                    }
                }

                impl super::PlanCallbackOut for $type {
                    fn from_pyany<'py>(pyany: Py<PyAny>, _py: Python<'py>) -> PyResult<Self> {
                        let registry = polars_utils::python_convert_registry::get_python_convert_registry();
                        let obj = (registry.from_py.$from)(pyany)?;
                        // The registry's `from_py` entry for this type must produce exactly
                        // `Self`. A mismatch is a registration bug, not a user error.
                        let obj = obj
                            .downcast()
                            .expect("python convert registry returned a value of an unexpected type");
                        Ok(*obj)
                    }
                }
            )+
        };
    }

    /// `None` maps to Python `None`; `Some(v)` converts `v`.
    impl<T: super::PlanCallbackArgs> super::PlanCallbackArgs for Option<T> {
        fn into_pyany<'py>(self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
            match self {
                None => Ok(py.None()),
                Some(v) => v.into_pyany(py),
            }
        }
    }

    /// Python `None` maps to `None`; anything else is extracted as `T`.
    impl<T: super::PlanCallbackOut> super::PlanCallbackOut for Option<T> {
        fn from_pyany<'py>(
            pyany: pyo3::Py<pyo3::PyAny>,
            py: pyo3::Python<'py>,
        ) -> pyo3::PyResult<Self> {
            if pyany.is_none(py) {
                Ok(None)
            } else {
                T::from_pyany(pyany, py).map(Some)
            }
        }
    }

    /// A pair is passed to Python as a 2-tuple.
    impl<T, U> super::PlanCallbackArgs for (T, U)
    where
        T: super::PlanCallbackArgs,
        U: super::PlanCallbackArgs,
    {
        fn into_pyany<'py>(self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
            PyTuple::new(py, [self.0.into_pyany(py)?, self.1.into_pyany(py)?])?.into_py_any(py)
        }
    }

    /// A pair is extracted from a Python tuple of exactly two items.
    ///
    /// # Errors
    /// Fails if the object is not a tuple, if the tuple length is not 2, or if either
    /// item fails to convert.
    impl<T, U> super::PlanCallbackOut for (T, U)
    where
        T: super::PlanCallbackOut,
        U: super::PlanCallbackOut,
    {
        fn from_pyany<'py>(
            pyany: pyo3::Py<pyo3::PyAny>,
            py: pyo3::Python<'py>,
        ) -> pyo3::PyResult<Self> {
            use pyo3::prelude::*;
            let tuple = pyany.cast_bound::<PyTuple>(py)?;
            // Reject wrong arity instead of silently dropping extra items.
            if tuple.len() != 2 {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "expected a tuple of length 2 from Python callback, got length {}",
                    tuple.len()
                )));
            }
            Ok((
                T::from_pyany(tuple.get_item(0)?.unbind(), py)?,
                U::from_pyany(tuple.get_item(1)?.unbind(), py)?,
            ))
        }
    }

    impl_pycb_type! {
        bool,
        usize,
        String
    }
    impl_pycb_type_to_from! {
        PlSmallStr => String
    }
    impl_registrycb_type! {
        (polars_core::series::Series, series, series),
        (polars_core::frame::DataFrame, df, df),
        (crate::dsl::DslPlan, dsl_plan, dsl_plan),
        (polars_core::schema::Schema, schema, schema)
    }

    /// Passes the inner value, cloning it only if the `Arc` is shared.
    impl<T: super::PlanCallbackArgs + Clone> super::PlanCallbackArgs for Arc<T> {
        fn into_pyany<'py>(self, py: Python<'py>) -> PyResult<Py<PyAny>> {
            Arc::unwrap_or_clone(self).into_pyany(py)
        }
    }

    /// A vector is passed to Python as a list of converted elements.
    // No `Clone` bound is needed: the vector is consumed element by element.
    impl<T: super::PlanCallbackArgs> super::PlanCallbackArgs for Vec<T> {
        fn into_pyany<'py>(self, py: Python<'py>) -> PyResult<Py<PyAny>> {
            let items: Vec<Py<PyAny>> = self
                .into_iter()
                .map(|v| v.into_pyany(py))
                .collect::<PyResult<Vec<_>>>()?;

            Ok(PyList::new(py, items)?.into())
        }
    }

    /// Extracts `T` and wraps it in a fresh `Arc`.
    impl<T: super::PlanCallbackOut> super::PlanCallbackOut for Arc<T> {
        fn from_pyany<'py>(pyany: Py<PyAny>, py: Python<'py>) -> PyResult<Self> {
            T::from_pyany(pyany, py).map(Arc::from)
        }
    }
}
273
274
#[cfg(not(feature = "python"))]
275
mod _no_python {
276
impl<T> super::PlanCallbackArgs for T {}
277
impl<T: Sized> super::PlanCallbackOut for T {}
278
}
279
280
impl<Args: PlanCallbackArgs, Out: PlanCallbackOut> PlanCallback<Args, Out> {
281
pub fn call(&self, args: Args) -> PolarsResult<Out> {
282
match self {
283
#[cfg(feature = "python")]
284
Self::Python(pyfn) => pyo3::Python::attach(|py| {
285
let out = Out::from_pyany(pyfn.call1(py, (args.into_pyany(py)?,))?, py)?;
286
Ok(out)
287
}),
288
Self::Rust(f) => f(args),
289
}
290
}
291
292
#[cfg(feature = "python")]
293
pub fn new_python(pyfn: polars_utils::python_function::PythonFunction) -> Self {
294
Self::Python(SpecialEq::new(Arc::new(pyfn)))
295
}
296
297
pub fn new(f: impl Fn(Args) -> PolarsResult<Out> + Send + Sync + 'static) -> Self {
298
Self::Rust(SpecialEq::new(Arc::new(f) as _))
299
}
300
}
301
302