Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/callback.rs
6939 views
1
use std::fmt;
2
use std::sync::Arc;
3
4
use polars_error::PolarsResult;
5
6
use crate::dsl::SpecialEq;
7
8
#[derive(Eq, PartialEq)]
9
pub enum PlanCallback<Args, Out> {
10
#[cfg(feature = "python")]
11
Python(SpecialEq<Arc<polars_utils::python_function::PythonFunction>>),
12
Rust(SpecialEq<Arc<dyn Fn(Args) -> PolarsResult<Out> + Send + Sync>>),
13
}
14
15
impl<Args, Out> fmt::Debug for PlanCallback<Args, Out> {
16
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
17
f.write_str("PlanCallback::")?;
18
std::mem::discriminant(self).fmt(f)
19
}
20
}
21
22
/// Serializes a callback; only the Python variant is serializable.
///
/// A Python callback defers to the serialization of its wrapped function. A Rust closure cannot
/// be serialized, so a custom serializer error is returned that includes the callback's `Debug`
/// representation.
#[cfg(feature = "serde")]
impl<Args, Out> serde::Serialize for PlanCallback<Args, Out> {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            #[cfg(feature = "python")]
            Self::Python(pyfn) => pyfn.serialize(_serializer),
            Self::Rust(_) => Err(<S::Error as serde::ser::Error>::custom(format!(
                "cannot serialize 'opaque' function in {self:?}"
            ))),
        }
    }
}
40
41
/// Deserializes a callback; only Python functions can be reconstructed.
///
/// With the `python` feature the input is decoded as a `PythonFunction` and wrapped in the
/// `Python` variant. Without that feature deserialization always fails, since a Rust closure
/// cannot be rebuilt from data.
#[cfg(feature = "serde")]
impl<'de, Args, Out> serde::Deserialize<'de> for PlanCallback<Args, Out> {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            let function =
                polars_utils::python_function::PythonFunction::deserialize(_deserializer)?;
            let shared = Arc::new(function);
            Ok(Self::Python(SpecialEq::new(shared)))
        }
        #[cfg(not(feature = "python"))]
        {
            Err(<D::Error as serde::de::Error>::custom(
                "cannot deserialize PlanCallback",
            ))
        }
    }
}
60
61
/// JSON schema for a callback, shared by every `Args`/`Out` instantiation.
///
/// The schema is named `PlanCallback` and reuses the schema of `Vec<u8>` (NOTE(review):
/// presumably because the serialized form is an opaque byte payload — confirm).
#[cfg(feature = "dsl-schema")]
impl<Args, Out> schemars::JsonSchema for PlanCallback<Args, Out> {
    fn schema_name() -> String {
        String::from("PlanCallback")
    }

    fn schema_id() -> std::borrow::Cow<'static, str> {
        // Qualified by this module's path so it cannot collide with an unrelated `PlanCallback`.
        let id: &'static str = concat!(module_path!(), "::PlanCallback");
        id.into()
    }

    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
75
76
impl<Args, Out> std::hash::Hash for PlanCallback<Args, Out> {
77
fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {
78
// no-op.
79
}
80
}
81
82
impl<Args, Out> Clone for PlanCallback<Args, Out> {
83
fn clone(&self) -> Self {
84
match self {
85
#[cfg(feature = "python")]
86
Self::Python(p) => Self::Python(p.clone()),
87
Self::Rust(f) => Self::Rust(f.clone()),
88
}
89
}
90
}
91
92
/// A value that can be passed as the argument of a [`PlanCallback`].
///
/// With the `python` feature, implementors must be convertible into a Python object; without it
/// the trait has no methods.
pub trait PlanCallbackArgs {
    /// Consumes `self` and converts it into a Python object for a Python callback.
    #[cfg(feature = "python")]
    fn into_pyany<'py>(
        self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>>;
}
96
/// A value that can be returned from a [`PlanCallback`].
///
/// Implementors are `Sized`; `from_pyany` produces `Self` by value. Without the `python` feature
/// the trait has no methods.
pub trait PlanCallbackOut: Sized {
    /// Builds a value of this type from the object a Python callback returned.
    #[cfg(feature = "python")]
    fn from_pyany<'py>(
        pyany: pyo3::Py<pyo3::PyAny>,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<Self>;
}
101
102
/// pyo3-backed implementations of the callback conversion traits.
///
/// Only compiled with the `python` feature. `usize` and `String` convert through pyo3's own
/// `IntoPyObject` / `extract`; polars types go through the conversion functions exposed by
/// `polars_utils::python_convert_registry::get_python_convert_registry()`. `Option` maps to
/// `None`, and 2-tuples map to Python tuples of length 2.
#[cfg(feature = "python")]
mod _python {
    use pyo3::types::{PyAnyMethods, PyTuple};
    use pyo3::*;

    // Implements both callback traits for types that pyo3 converts natively.
    macro_rules! impl_pycb_type {
        ($($type:ty),+) => {
            $(
                impl super::PlanCallbackArgs for $type {
                    fn into_pyany<'py>(self, py: Python<'py>) -> PyResult<Py<PyAny>> {
                        Ok(self.into_pyobject(py)?.into_any().unbind())
                    }
                }

                impl super::PlanCallbackOut for $type {
                    fn from_pyany<'py>(pyany: Py<PyAny>, py: Python<'py>) -> PyResult<Self> {
                        pyany.bind(py).extract::<Self>()
                    }
                }
            )+
        };
    }

    // Implements both callback traits for types converted through the Python convert registry.
    // `$from` / `$to` name the registry's `from_py` / `to_py` conversion fields for `$type`.
    macro_rules! impl_registrycb_type {
        ($(($type:path, $from:ident, $to:ident)),+) => {
            $(
                impl super::PlanCallbackArgs for $type {
                    fn into_pyany<'py>(self, _py: Python<'py>) -> PyResult<Py<PyAny>> {
                        let registry = polars_utils::python_convert_registry::get_python_convert_registry();
                        (registry.to_py.$to)(Box::new(self) as _)
                    }
                }

                impl super::PlanCallbackOut for $type {
                    fn from_pyany<'py>(pyany: Py<PyAny>, _py: Python<'py>) -> PyResult<Self> {
                        let registry = polars_utils::python_convert_registry::get_python_convert_registry();
                        let obj = (registry.from_py.$from)(pyany)?;
                        // A mismatched type here presumably means the registry is wired up
                        // incorrectly (a bug), so fail loudly and name the expected type.
                        let obj = obj.downcast().expect(concat!(
                            "python convert registry returned a value that is not a ",
                            stringify!($type)
                        ));
                        Ok(*obj)
                    }
                }
            )+
        };
    }

    /// `None` becomes Python `None`; `Some(v)` converts `v`.
    impl<T: super::PlanCallbackArgs> super::PlanCallbackArgs for Option<T> {
        fn into_pyany<'py>(self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
            match self {
                None => Ok(py.None()),
                Some(v) => v.into_pyany(py),
            }
        }
    }

    /// Python `None` becomes `None`; any other object is converted as `T`.
    impl<T: super::PlanCallbackOut> super::PlanCallbackOut for Option<T> {
        fn from_pyany<'py>(
            pyany: pyo3::Py<pyo3::PyAny>,
            py: pyo3::Python<'py>,
        ) -> pyo3::PyResult<Self> {
            if pyany.is_none(py) {
                Ok(None)
            } else {
                T::from_pyany(pyany, py).map(Some)
            }
        }
    }

    /// A pair is passed to Python as a tuple of length 2.
    impl<T, U> super::PlanCallbackArgs for (T, U)
    where
        T: super::PlanCallbackArgs,
        U: super::PlanCallbackArgs,
    {
        fn into_pyany<'py>(self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
            PyTuple::new(py, [self.0.into_pyany(py)?, self.1.into_pyany(py)?])?.into_py_any(py)
        }
    }

    /// A pair is read back from a Python tuple that must have exactly 2 elements.
    ///
    /// # Errors
    /// Fails if the object is not a tuple, if its length is not 2 (raised as `ValueError`), or
    /// if either element fails to convert.
    impl<T, U> super::PlanCallbackOut for (T, U)
    where
        T: super::PlanCallbackOut,
        U: super::PlanCallbackOut,
    {
        fn from_pyany<'py>(
            pyany: pyo3::Py<pyo3::PyAny>,
            py: pyo3::Python<'py>,
        ) -> pyo3::PyResult<Self> {
            use pyo3::prelude::*;
            let tuple = pyany.downcast_bound::<PyTuple>(py)?;
            // Reject the wrong arity instead of silently discarding trailing elements.
            let len = tuple.len();
            if len != 2 {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "expected a tuple of length 2 from the callback, got length {len}"
                )));
            }
            Ok((
                T::from_pyany(tuple.get_item(0)?.unbind(), py)?,
                U::from_pyany(tuple.get_item(1)?.unbind(), py)?,
            ))
        }
    }

    impl_pycb_type! {
        usize,
        String
    }
    impl_registrycb_type! {
        (polars_core::series::Series, series, series),
        (polars_core::frame::DataFrame, df, df),
        (crate::dsl::DslPlan, dsl_plan, dsl_plan),
        (polars_core::schema::Schema, schema, schema)
    }
}
208
209
#[cfg(not(feature = "python"))]
210
mod _no_python {
211
impl<T> super::PlanCallbackArgs for T {}
212
impl<T: Sized> super::PlanCallbackOut for T {}
213
}
214
215
impl<Args: PlanCallbackArgs, Out: PlanCallbackOut> PlanCallback<Args, Out> {
216
pub fn call(&self, args: Args) -> PolarsResult<Out> {
217
match self {
218
#[cfg(feature = "python")]
219
Self::Python(pyfn) => pyo3::Python::with_gil(|py| {
220
let out = Out::from_pyany(pyfn.call1(py, (args.into_pyany(py)?,))?, py)?;
221
Ok(out)
222
}),
223
Self::Rust(f) => f(args),
224
}
225
}
226
227
#[cfg(feature = "python")]
228
pub fn new_python(pyfn: polars_utils::python_function::PythonFunction) -> Self {
229
Self::Python(SpecialEq::new(Arc::new(pyfn)))
230
}
231
232
pub fn new(f: impl Fn(Args) -> PolarsResult<Out> + Send + Sync + 'static) -> Self {
233
Self::Rust(SpecialEq::new(Arc::new(f) as _))
234
}
235
}
236
237