Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/expr/anonymous/serde_expr.rs
8416 views
1
use std::sync::Arc;
2
3
use polars_core::series::Series;
4
use polars_error::*;
5
use polars_utils::pl_serialize::deserialize_map_bytes;
6
use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8
use super::named_serde::ExprRegistry;
9
use super::*;
10
use crate::dsl::LazySerde;
11
12
/// Byte prefix that marks a serialized buffer as a named-serde function reference.
const NAMED_SERDE_MAGIC_BYTE_MARK: &[u8] = b"PLNAMEDFN";
13
/// Terminator byte written directly after the function name in a named-serde
/// buffer; any payload bytes follow it.
const NAMED_SERDE_MAGIC_BYTE_END: u8 = b'!';
14
15
fn serialize_named<S: Serializer>(
16
serializer: S,
17
name: &str,
18
payload: Option<&[u8]>,
19
) -> Result<S::Ok, S::Error> {
20
let mut buf = vec![];
21
buf.extend_from_slice(NAMED_SERDE_MAGIC_BYTE_MARK);
22
buf.extend_from_slice(name.as_bytes());
23
buf.push(NAMED_SERDE_MAGIC_BYTE_END);
24
if let Some(payload) = payload {
25
buf.extend_from_slice(payload);
26
}
27
serializer.serialize_bytes(&buf)
28
}
29
30
fn deserialize_named_registry(buf: &[u8]) -> PolarsResult<(Arc<dyn ExprRegistry>, &str, &[u8])> {
31
let bytes = &buf[NAMED_SERDE_MAGIC_BYTE_MARK.len()..];
32
let Some(pos) = bytes.iter().position(|b| *b == NAMED_SERDE_MAGIC_BYTE_END) else {
33
polars_bail!(ComputeError: "named-serde expected magic byte end")
34
};
35
36
let Ok(name) = std::str::from_utf8(&bytes[..pos]) else {
37
polars_bail!(ComputeError: "named-serde name should be valid utf8")
38
};
39
let payload = &bytes[pos + 1..];
40
41
let registry = named_serde::NAMED_SERDE_REGISTRY_EXPR.read().unwrap();
42
match &*registry {
43
Some(reg) => Ok((reg.clone(), name, payload)),
44
None => polars_bail!(ComputeError: "named serde registry not set"),
45
}
46
}
47
48
impl Serialize for SpecialEq<Arc<dyn AnonymousAgg>> {
49
fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
50
where
51
S: serde::Serializer,
52
{
53
unreachable!("should not be hit")
54
}
55
}
56
57
impl Serialize for SpecialEq<Arc<dyn AnonymousColumnsUdf>> {
58
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
59
where
60
S: Serializer,
61
{
62
use serde::ser::Error;
63
let mut buf = vec![];
64
self.as_ref()
65
.try_serialize(&mut buf)
66
.map_err(|e| S::Error::custom(format!("{e}")))?;
67
serializer.serialize_bytes(&buf)
68
}
69
}
70
71
impl<T: Serialize + Clone> Serialize for LazySerde<T> {
72
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73
where
74
S: Serializer,
75
{
76
match self {
77
Self::Named {
78
name,
79
payload,
80
value: _,
81
} => serialize_named(serializer, name, payload.as_deref()),
82
Self::Deserialized(t) => t.serialize(serializer),
83
Self::Bytes(b) => b.serialize(serializer),
84
}
85
}
86
}
87
88
impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
89
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
90
where
91
D: Deserializer<'a>,
92
{
93
let buf = bytes::Bytes::deserialize(deserializer)?;
94
Ok(Self::Bytes(buf))
95
}
96
}
97
98
pub(super) fn deserialize_column_udf(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {
99
#[cfg(feature = "python")]
100
if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
101
return crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(buf);
102
};
103
104
if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
105
let (reg, name, payload) = deserialize_named_registry(buf)?;
106
107
if let Some(func) = reg.get_function(name, payload) {
108
Ok(func)
109
} else {
110
let msg = "name not found in named serde registry";
111
polars_bail!(ComputeError: msg)
112
}
113
} else {
114
polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function")
115
}
116
}
117
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn AnonymousColumnsUdf>> {
118
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119
where
120
D: Deserializer<'a>,
121
{
122
use serde::de::Error;
123
deserialize_map_bytes(deserializer, |buf| {
124
deserialize_column_udf(&buf)
125
.map_err(|e| D::Error::custom(format!("{e}")))
126
.map(SpecialEq::new)
127
})?
128
}
129
}
130
131
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn AnonymousAgg>> {
132
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
133
where
134
D: Deserializer<'a>,
135
{
136
use serde::de::Error;
137
deserialize_map_bytes(deserializer, |buf| {
138
deserialize_anon_agg(&buf)
139
.map_err(|e| D::Error::custom(format!("{e}")))
140
.map(SpecialEq::new)
141
})?
142
}
143
}
144
145
pub(super) fn deserialize_anon_agg(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousAgg>> {
146
if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
147
let (reg, name, payload) = deserialize_named_registry(buf)?;
148
149
if let Some(func) = reg.get_agg(name, payload)? {
150
Ok(func)
151
} else {
152
let msg = "name not found in named serde registry";
153
polars_bail!(ComputeError: msg)
154
}
155
} else {
156
polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function")
157
}
158
}
159
160
// Serialize / Deserialize impls for SpecialEq<T>
161
162
impl Serialize for SpecialEq<Series> {
163
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
164
where
165
S: Serializer,
166
{
167
let s: &Series = self;
168
s.serialize(serializer)
169
}
170
}
171
172
impl<'a> Deserialize<'a> for SpecialEq<Series> {
173
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
174
where
175
D: Deserializer<'a>,
176
{
177
let t = Series::deserialize(deserializer)?;
178
Ok(SpecialEq::new(t))
179
}
180
}
181
182
impl<T: Serialize> Serialize for SpecialEq<Arc<T>> {
183
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
184
where
185
S: Serializer,
186
{
187
self.as_ref().serialize(serializer)
188
}
189
}
190
191
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<Arc<T>> {
    /// Deserializes a `T`, places it in a fresh `Arc`, and wraps that in
    /// `SpecialEq`.
    fn deserialize<D: Deserializer<'a>>(deserializer: D) -> Result<Self, D::Error> {
        T::deserialize(deserializer).map(|value| SpecialEq::new(Arc::new(value)))
    }
}
201
202