Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-utils/src/pl_serialize.rs
6939 views
1
//! Centralized Polars serialization entry.
2
//!
3
//! Currently provides two serialization schemes:
4
//! - Self-describing (and thus more forward compatible) activated with `FC: true`
5
//! - Compact activated with `FC: false`
6
use polars_error::{PolarsResult, to_compute_err};
7
8
/// Bincode configuration shared by the compact (`FC = false`) encode and decode paths:
/// the standard configuration with no size limit and variable-length integer encoding.
fn config() -> bincode::config::Configuration {
    let unlimited = bincode::config::standard().with_no_limit();
    unlimited.with_variable_int_encoding()
}
13
14
fn serialize_impl<W, T, const FC: bool>(mut writer: W, value: &T) -> PolarsResult<()>
15
where
16
W: std::io::Write,
17
T: serde::ser::Serialize,
18
{
19
if FC {
20
let mut s = rmp_serde::Serializer::new(writer).with_struct_map();
21
value.serialize(&mut s).map_err(to_compute_err)
22
} else {
23
bincode::serde::encode_into_std_write(value, &mut writer, config())
24
.map_err(to_compute_err)
25
.map(|_| ())
26
}
27
}
28
29
pub fn deserialize_impl<T, R, const FC: bool>(mut reader: R) -> PolarsResult<T>
30
where
31
T: serde::de::DeserializeOwned,
32
R: std::io::Read,
33
{
34
if FC {
35
rmp_serde::from_read(reader).map_err(to_compute_err)
36
} else {
37
bincode::serde::decode_from_std_read(&mut reader, config()).map_err(to_compute_err)
38
}
39
}
40
41
/// Options for serializing the final, outermost value.
///
/// Mainly used to enable compression. For intermediate serialization steps, the free
/// functions in this module should be used instead.
#[derive(Debug, Clone)]
pub struct SerializeOptions {
    /// Whether the serialized stream is zlib-compressed.
    compression: bool,
}
47
48
impl SerializeOptions {
49
pub fn with_compression(mut self, compression: bool) -> Self {
50
self.compression = compression;
51
self
52
}
53
54
pub fn serialize_into_writer<W, T, const FC: bool>(
55
&self,
56
writer: W,
57
value: &T,
58
) -> PolarsResult<()>
59
where
60
W: std::io::Write,
61
T: serde::ser::Serialize,
62
{
63
if self.compression {
64
let writer = flate2::write::ZlibEncoder::new(writer, flate2::Compression::fast());
65
serialize_impl::<_, _, FC>(writer, value)
66
} else {
67
serialize_impl::<_, _, FC>(writer, value)
68
}
69
}
70
71
pub fn deserialize_from_reader<T, R, const FC: bool>(&self, reader: R) -> PolarsResult<T>
72
where
73
T: serde::de::DeserializeOwned,
74
R: std::io::Read,
75
{
76
if self.compression {
77
deserialize_impl::<_, _, FC>(flate2::read::ZlibDecoder::new(reader))
78
} else {
79
deserialize_impl::<_, _, FC>(reader)
80
}
81
}
82
83
pub fn serialize_to_bytes<T, const FC: bool>(&self, value: &T) -> PolarsResult<Vec<u8>>
84
where
85
T: serde::ser::Serialize,
86
{
87
let mut v = vec![];
88
89
self.serialize_into_writer::<_, _, FC>(&mut v, value)?;
90
91
Ok(v)
92
}
93
}
94
95
#[allow(clippy::derivable_impls)]
96
impl Default for SerializeOptions {
97
fn default() -> Self {
98
Self { compression: false }
99
}
100
}
101
102
pub fn serialize_into_writer<W, T, const FC: bool>(writer: W, value: &T) -> PolarsResult<()>
103
where
104
W: std::io::Write,
105
T: serde::ser::Serialize,
106
{
107
serialize_impl::<_, _, FC>(writer, value)
108
}
109
110
pub fn deserialize_from_reader<T, R, const FC: bool>(reader: R) -> PolarsResult<T>
111
where
112
T: serde::de::DeserializeOwned,
113
R: std::io::Read,
114
{
115
deserialize_impl::<_, _, FC>(reader)
116
}
117
118
pub fn serialize_to_bytes<T, const FC: bool>(value: &T) -> PolarsResult<Vec<u8>>
119
where
120
T: serde::ser::Serialize,
121
{
122
let mut v = vec![];
123
124
serialize_into_writer::<_, _, FC>(&mut v, value)?;
125
126
Ok(v)
127
}
128
129
/// Serialize function customized for `DslPlan`, with stack overflow protection.
130
pub fn serialize_dsl<W, T>(writer: W, value: &T) -> PolarsResult<()>
131
where
132
W: std::io::Write,
133
T: serde::ser::Serialize,
134
{
135
let mut s = rmp_serde::Serializer::new(writer).with_struct_map();
136
let s = serde_stacker::Serializer::new(&mut s);
137
value.serialize(s).map_err(to_compute_err)
138
}
139
140
/// Deserialize function customized for `DslPlan`, with stack overflow protection.
141
pub fn deserialize_dsl<T, R>(reader: R) -> PolarsResult<T>
142
where
143
T: serde::de::DeserializeOwned,
144
R: std::io::Read,
145
{
146
let mut de = rmp_serde::Deserializer::new(reader);
147
de.set_max_depth(usize::MAX);
148
let de = serde_stacker::Deserializer::new(&mut de);
149
T::deserialize(de).map_err(to_compute_err)
150
}
151
152
/// Potentially avoids copying memory compared to a naive `Vec::<u8>::deserialize`.
153
///
154
/// This is essentially boilerplate for visiting bytes without copying where possible.
155
pub fn deserialize_map_bytes<'de, D, O>(
156
deserializer: D,
157
mut func: impl for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O,
158
) -> Result<O, D::Error>
159
where
160
D: serde::de::Deserializer<'de>,
161
{
162
// Lets us avoid monomorphizing the visitor
163
let mut out: Option<O> = None;
164
struct V<'f>(&'f mut dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>));
165
166
deserializer.deserialize_bytes(V(&mut |v| drop(out.replace(func(v)))))?;
167
168
return Ok(out.unwrap());
169
170
impl<'de> serde::de::Visitor<'de> for V<'_> {
171
type Value = ();
172
173
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
174
formatter.write_str("deserialize_map_bytes")
175
}
176
177
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
178
where
179
E: serde::de::Error,
180
{
181
self.0(std::borrow::Cow::Borrowed(v));
182
Ok(())
183
}
184
185
fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
186
where
187
E: serde::de::Error,
188
{
189
self.0(std::borrow::Cow::Owned(v));
190
Ok(())
191
}
192
193
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
194
where
195
A: serde::de::SeqAccess<'de>,
196
{
197
// This is not ideal, but we hit here if the serialization format is JSON.
198
let bytes = std::iter::from_fn(|| seq.next_element::<u8>().transpose())
199
.collect::<Result<Vec<_>, A::Error>>()?;
200
201
self.0(std::borrow::Cow::Owned(bytes));
202
Ok(())
203
}
204
}
205
}
206
207
thread_local! {
    /// Per-thread flag, `false` by default. When `true`, `python_object_serialize`
    /// pickles with `cloudpickle` directly instead of trying `pickle` first.
    pub static USE_CLOUDPICKLE: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
210
211
/// Pickles `pyobj` and appends the result to `buf`.
///
/// Layout appended to `buf`: one flag byte (1 if `cloudpickle` produced the payload,
/// 0 otherwise), the bytes of `PYTHON3_VERSION`, then the pickled payload.
///
/// `cloudpickle` is used directly when `USE_CLOUDPICKLE` is set for the current thread;
/// otherwise `pickle` is tried first and `cloudpickle` is the fallback if `pickle.dumps`
/// raises.
#[cfg(feature = "python")]
pub fn python_object_serialize(
    pyobj: &pyo3::Py<pyo3::PyAny>,
    buf: &mut Vec<u8>,
) -> PolarsResult<()> {
    use pyo3::Python;
    use pyo3::pybacked::PyBackedBytes;
    use pyo3::types::{PyAnyMethods, PyModule};

    use crate::python_function::PYTHON3_VERSION;

    let mut use_cloudpickle = USE_CLOUDPICKLE.get();
    let dumped = Python::with_gil(|py| {
        // Looks up `<module>.dumps`; import and attribute errors are propagated.
        let dumps = |module: &str| PyModule::import(py, module)?.getattr("dumps");

        // Pickle with whatever pickling method was selected.
        let pickled = if use_cloudpickle {
            dumps("cloudpickle")?.call1((pyobj.clone_ref(py),))?
        } else {
            match dumps("pickle")?.call1((pyobj.clone_ref(py),)) {
                Ok(pickled) => pickled,
                Err(_) => {
                    // Plain pickle could not handle the object: retry with cloudpickle
                    // and record that in the metadata flag.
                    use_cloudpickle = true;
                    dumps("cloudpickle")?.call1((pyobj.clone_ref(py),))?
                },
            }
        };

        pickled.extract::<PyBackedBytes>()
    })?;

    // Write pickle metadata
    buf.push(use_cloudpickle as u8);
    buf.extend_from_slice(&*PYTHON3_VERSION);

    // Write UDF
    buf.extend_from_slice(&dumped);
    Ok(())
}
250
251
/// Reconstructs a Python object from bytes produced by `python_object_serialize`.
///
/// Expected layout of `buf`: one flag byte (non-zero if `cloudpickle` produced the
/// payload), two Python-version bytes, then the pickled payload.
///
/// # Errors
///
/// - `buf` is shorter than the 3-byte metadata header.
/// - The payload was produced by `cloudpickle` under a Python version that differs from
///   the current `PYTHON3_VERSION`.
/// - Importing `pickle` or calling `pickle.loads` fails.
///
/// NOTE: unpickling can execute arbitrary code; `buf` must come from a trusted source.
#[cfg(feature = "python")]
pub fn python_object_deserialize(buf: &[u8]) -> PolarsResult<pyo3::Py<pyo3::PyAny>> {
    use polars_error::polars_ensure;
    use pyo3::Python;
    use pyo3::types::{PyAnyMethods, PyBytes, PyModule};

    use crate::python_function::PYTHON3_VERSION;

    // Metadata header: 1 flag byte + 2 Python-version bytes.
    const HEADER_LEN: usize = 3;

    // Reject truncated input up front instead of panicking on the slice indexing below.
    polars_ensure!(
        buf.len() >= HEADER_LEN,
        ComputeError:
        "serialized Python object is too short: expected at least {} bytes of metadata, got {}",
        HEADER_LEN,
        buf.len()
    );
    let (header, payload) = buf.split_at(HEADER_LEN);

    // Handle pickle metadata
    let use_cloudpickle = header[0] != 0;
    if use_cloudpickle {
        let ser_py_version = &header[1..3];
        let cur_py_version = *PYTHON3_VERSION;
        polars_ensure!(
            ser_py_version == cur_py_version,
            InvalidOperation:
            "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
            (3, cur_py_version[0], cur_py_version[1]),
            (3, ser_py_version[0], ser_py_version[1] )
        );
    }

    Python::with_gil(|py| {
        let loads = PyModule::import(py, "pickle")?.getattr("loads")?;
        let arg = (PyBytes::new(py, payload),);
        let python_function = loads.call1(arg)?;
        Ok(python_function.into())
    })
}
281
282
#[cfg(test)]
mod tests {
    /// Round-trips an enum that has a `#[serde(skip)]` variant through the compact
    /// (`FC = false`) scheme, via both the free functions and `SerializeOptions`.
    #[test]
    fn test_serde_skip_enum() {
        #[derive(Default, Debug, PartialEq)]
        struct MyType(Option<usize>);

        // Note: serde(skip) must be at the end of enums
        #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
        enum Enum {
            A,
            #[serde(skip)]
            B(MyType),
        }

        impl Default for Enum {
            fn default() -> Self {
                Self::B(MyType(None))
            }
        }

        let original = Enum::A;

        // Free functions.
        let encoded = super::serialize_to_bytes::<_, false>(&original).unwrap();
        let decoded: Enum =
            super::deserialize_from_reader::<_, _, false>(encoded.as_slice()).unwrap();
        assert_eq!(decoded, original);

        // Default `SerializeOptions`.
        let options = super::SerializeOptions::default();
        let encoded = options.serialize_to_bytes::<_, false>(&original).unwrap();
        let decoded: Enum = options
            .deserialize_from_reader::<_, _, false>(encoded.as_slice())
            .unwrap();
        assert_eq!(decoded, original);
    }
}
320
321