Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-utils/src/python_function.rs
6939 views
1
use pyo3::prelude::*;
2
#[cfg(feature = "serde")]
3
pub use serde_wrap::{
4
PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
5
TrySerializeToBytes,
6
};
7
8
/// Wrapper around PyObject from pyo3 with additional trait impls.
9
#[derive(Debug)]
10
pub struct PythonObject(pub PyObject);
11
// Note: We have this because the struct itself used to be called `PythonFunction`, so it's
12
// referred to as such from a lot of places.
13
pub type PythonFunction = PythonObject;
14
15
impl std::ops::Deref for PythonObject {
16
type Target = PyObject;
17
18
fn deref(&self) -> &Self::Target {
19
&self.0
20
}
21
}
22
23
impl std::ops::DerefMut for PythonObject {
24
fn deref_mut(&mut self) -> &mut Self::Target {
25
&mut self.0
26
}
27
}
28
29
impl Clone for PythonObject {
30
fn clone(&self) -> Self {
31
Python::with_gil(|py| Self(self.0.clone_ref(py)))
32
}
33
}
34
35
impl From<PyObject> for PythonObject {
36
fn from(value: PyObject) -> Self {
37
Self(value)
38
}
39
}
40
41
impl<'py> pyo3::conversion::IntoPyObject<'py> for PythonObject {
42
type Target = PyAny;
43
type Output = Bound<'py, Self::Target>;
44
type Error = PyErr;
45
46
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
47
Ok(self.0.into_bound(py))
48
}
49
}
50
51
impl<'py> pyo3::conversion::IntoPyObject<'py> for &PythonObject {
52
type Target = PyAny;
53
type Output = Bound<'py, Self::Target>;
54
type Error = PyErr;
55
56
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
57
Ok(self.0.bind(py).clone())
58
}
59
}
60
61
impl Eq for PythonObject {}
62
63
impl PartialEq for PythonObject {
64
fn eq(&self, other: &Self) -> bool {
65
Python::with_gil(|py| {
66
let eq = self.0.getattr(py, "__eq__").unwrap();
67
eq.call1(py, (other.0.clone_ref(py),))
68
.unwrap()
69
.extract::<bool>(py)
70
// equality can be not implemented, so default to false
71
.unwrap_or(false)
72
})
73
}
74
}
75
76
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for PythonObject {
    fn schema_name() -> String {
        String::from("PythonObject")
    }

    /// Schema id qualified with this module's path so that it stays unique.
    fn schema_id() -> std::borrow::Cow<'static, str> {
        let id: &'static str = concat!(module_path!(), "::PythonObject");
        id.into()
    }

    /// A `PythonObject` is described by the same schema as a `Vec<u8>`, matching its serde
    /// representation as a byte vector.
    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
90
91
#[cfg(feature = "serde")]
mod _serde_impls {
    use super::{PySerializeWrap, PythonObject, TrySerializeToBytes};
    use crate::pl_serialize::deserialize_map_bytes;

    impl PythonObject {
        /// Adapter for serde's `serialize_with`: serializes anything that can be viewed as a
        /// `PythonObject` through [`PySerializeWrap`].
        pub fn serialize_with_pyversion<T, S>(
            value: &T,
            serializer: S,
        ) -> std::result::Result<S::Ok, S::Error>
        where
            T: AsRef<PythonObject>,
            S: serde::ser::Serializer,
        {
            let wrapped = PySerializeWrap(value.as_ref());
            serde::Serialize::serialize(&wrapped, serializer)
        }

        /// Adapter for serde's `deserialize_with`: decodes a `PythonObject` through
        /// [`PySerializeWrap`] and converts it into the target type `T`.
        pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
        where
            T: From<PythonObject>,
            D: serde::de::Deserializer<'de>,
        {
            let PySerializeWrap(obj) =
                <PySerializeWrap<PythonObject> as serde::Deserialize<'de>>::deserialize(d)?;
            Ok(T::from(obj))
        }
    }

    impl TrySerializeToBytes for PythonObject {
        /// Serialize the wrapped object into a fresh byte buffer via
        /// `pl_serialize::python_object_serialize`.
        fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
            let mut out: Vec<u8> = Vec::new();
            crate::pl_serialize::python_object_serialize(&self.0, &mut out)?;
            Ok(out)
        }

        /// Rebuild a `PythonObject` from bytes via `pl_serialize::python_object_deserialize`.
        fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
            let obj = crate::pl_serialize::python_object_deserialize(bytes)?;
            Ok(PythonObject(obj))
        }
    }

    impl serde::Serialize for PythonObject {
        /// Serialize as a `Vec<u8>` holding the output of `try_serialize_to_bytes`. A failure
        /// there is reported as a custom serializer error.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            match self.try_serialize_to_bytes() {
                Ok(bytes) => serde::Serialize::serialize(&bytes, serializer),
                Err(err) => Err(<S::Error as serde::ser::Error>::custom(err.to_string())),
            }
        }
    }

    impl<'a> serde::Deserialize<'a> for PythonObject {
        /// Read a byte payload through `deserialize_map_bytes` and decode it with
        /// `try_deserialize_bytes`. A decode failure becomes a custom deserializer error.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            deserialize_map_bytes(deserializer, |buf| {
                let decoded = <Self as TrySerializeToBytes>::try_deserialize_bytes(&buf);
                decoded.map_err(|err| <D::Error as serde::de::Error>::custom(err.to_string()))
            })?
        }
    }
}
159
160
#[cfg(feature = "serde")]
mod serde_wrap {
    use std::sync::LazyLock;

    use polars_error::PolarsResult;

    use crate::pl_serialize::deserialize_map_bytes;

    /// The ASCII marker `PLPYFN`, re-exported as `PYTHON_SERDE_MAGIC_BYTE_MARK`.
    /// NOTE(review): where it is written or checked is not visible in this file.
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = b"PLPYFN";

    /// `[minor, micro]` version of the Python 3 interpreter, computed once on first access.
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);

    /// Fallible conversion of a Python-backed value to and from raw bytes.
    ///
    /// Implementations encode only the object itself and are meant to be wrapped by
    /// [`PySerializeWrap`] for serde.
    pub trait TrySerializeToBytes: Sized {
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Serde adapter for any `T: TrySerializeToBytes`.
    ///
    /// NOTE(review): the impls below only forward the bytes produced by
    /// `try_serialize_to_bytes`. Any Python-version metadata would have to be added by that
    /// implementation — confirm.
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        /// Emit the encoded payload through `serialize_bytes`. Encoding failures are reported
        /// as a custom serializer error.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let PySerializeWrap(inner) = self;
            match inner.try_serialize_to_bytes() {
                Ok(payload) => serializer.serialize_bytes(&payload),
                Err(err) => Err(<S::Error as serde::ser::Error>::custom(err.to_string())),
            }
        }
    }

    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        /// Read a byte payload via `deserialize_map_bytes` and decode it into `T`. Decoding
        /// failures are reported as a custom deserializer error.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            deserialize_map_bytes(deserializer, |buf| {
                match T::try_deserialize_bytes(buf.as_ref()) {
                    Ok(value) => Ok(PySerializeWrap(value)),
                    Err(err) => Err(<D::Error as serde::de::Error>::custom(err.to_string())),
                }
            })?
        }
    }
}
213
214
/// Get the [minor, micro] Python3 version from the `sys` module.
215
fn get_python3_version() -> [u8; 2] {
216
Python::with_gil(|py| {
217
let version_info = PyModule::import(py, "sys")
218
.unwrap()
219
.getattr("version_info")
220
.unwrap();
221
222
[
223
version_info.getattr("minor").unwrap().extract().unwrap(),
224
version_info.getattr("micro").unwrap().extract().unwrap(),
225
]
226
})
227
}
228
229