// crates/polars-plan/src/dsl/expr/anonymous/serde_expr.rs
use std::sync::Arc;12use polars_core::series::Series;3use polars_error::*;4use polars_utils::pl_serialize::deserialize_map_bytes;5use serde::{Deserialize, Deserializer, Serialize, Serializer};67use super::named_serde::ExprRegistry;8use super::*;9use crate::dsl::LazySerde;1011const NAMED_SERDE_MAGIC_BYTE_MARK: &[u8] = "PLNAMEDFN".as_bytes();12const NAMED_SERDE_MAGIC_BYTE_END: u8 = b'!';1314fn serialize_named<S: Serializer>(15serializer: S,16name: &str,17payload: Option<&[u8]>,18) -> Result<S::Ok, S::Error> {19let mut buf = vec![];20buf.extend_from_slice(NAMED_SERDE_MAGIC_BYTE_MARK);21buf.extend_from_slice(name.as_bytes());22buf.push(NAMED_SERDE_MAGIC_BYTE_END);23if let Some(payload) = payload {24buf.extend_from_slice(payload);25}26serializer.serialize_bytes(&buf)27}2829fn deserialize_named_registry(buf: &[u8]) -> PolarsResult<(Arc<dyn ExprRegistry>, &str, &[u8])> {30let bytes = &buf[NAMED_SERDE_MAGIC_BYTE_MARK.len()..];31let Some(pos) = bytes.iter().position(|b| *b == NAMED_SERDE_MAGIC_BYTE_END) else {32polars_bail!(ComputeError: "named-serde expected magic byte end")33};3435let Ok(name) = std::str::from_utf8(&bytes[..pos]) else {36polars_bail!(ComputeError: "named-serde name should be valid utf8")37};38let payload = &bytes[pos + 1..];3940let registry = named_serde::NAMED_SERDE_REGISTRY_EXPR.read().unwrap();41match &*registry {42Some(reg) => Ok((reg.clone(), name, payload)),43None => polars_bail!(ComputeError: "named serde registry not set"),44}45}4647impl Serialize for SpecialEq<Arc<dyn AnonymousAgg>> {48fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>49where50S: serde::Serializer,51{52unreachable!("should not be hit")53}54}5556impl Serialize for SpecialEq<Arc<dyn AnonymousColumnsUdf>> {57fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>58where59S: Serializer,60{61use serde::ser::Error;62let mut buf = vec![];63self.as_ref()64.try_serialize(&mut buf)65.map_err(|e| 
S::Error::custom(format!("{e}")))?;66serializer.serialize_bytes(&buf)67}68}6970impl<T: Serialize + Clone> Serialize for LazySerde<T> {71fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>72where73S: Serializer,74{75match self {76Self::Named {77name,78payload,79value: _,80} => serialize_named(serializer, name, payload.as_deref()),81Self::Deserialized(t) => t.serialize(serializer),82Self::Bytes(b) => b.serialize(serializer),83}84}85}8687impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {88fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>89where90D: Deserializer<'a>,91{92let buf = bytes::Bytes::deserialize(deserializer)?;93Ok(Self::Bytes(buf))94}95}9697pub(super) fn deserialize_column_udf(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousColumnsUdf>> {98#[cfg(feature = "python")]99if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {100return crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(buf);101};102103if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {104let (reg, name, payload) = deserialize_named_registry(buf)?;105106if let Some(func) = reg.get_function(name, payload) {107Ok(func)108} else {109let msg = "name not found in named serde registry";110polars_bail!(ComputeError: msg)111}112} else {113polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function")114}115}116impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn AnonymousColumnsUdf>> {117fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>118where119D: Deserializer<'a>,120{121use serde::de::Error;122deserialize_map_bytes(deserializer, |buf| {123deserialize_column_udf(&buf)124.map_err(|e| D::Error::custom(format!("{e}")))125.map(SpecialEq::new)126})?127}128}129130impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn AnonymousAgg>> {131fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>132where133D: Deserializer<'a>,134{135use serde::de::Error;136deserialize_map_bytes(deserializer, |buf| 
{137deserialize_anon_agg(&buf)138.map_err(|e| D::Error::custom(format!("{e}")))139.map(SpecialEq::new)140})?141}142}143144pub(super) fn deserialize_anon_agg(buf: &[u8]) -> PolarsResult<Arc<dyn AnonymousAgg>> {145if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {146let (reg, name, payload) = deserialize_named_registry(buf)?;147148if let Some(func) = reg.get_agg(name, payload)? {149Ok(func)150} else {151let msg = "name not found in named serde registry";152polars_bail!(ComputeError: msg)153}154} else {155polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function")156}157}158159// Serialize SpecialEq<T>160161impl Serialize for SpecialEq<Series> {162fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>163where164S: Serializer,165{166let s: &Series = self;167s.serialize(serializer)168}169}170171impl<'a> Deserialize<'a> for SpecialEq<Series> {172fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>173where174D: Deserializer<'a>,175{176let t = Series::deserialize(deserializer)?;177Ok(SpecialEq::new(t))178}179}180181impl<T: Serialize> Serialize for SpecialEq<Arc<T>> {182fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>183where184S: Serializer,185{186self.as_ref().serialize(serializer)187}188}189190#[cfg(feature = "serde")]191impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<Arc<T>> {192fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>193where194D: Deserializer<'a>,195{196let t = T::deserialize(deserializer)?;197Ok(SpecialEq::new(Arc::new(t)))198}199}200201202