Path: blob/main/crates/polars-plan/src/dsl/expr/anonymous/expr.rs
8415 views
use std::fmt::{Debug, Formatter};1use std::hash::{Hash, Hasher};2use std::ops::Deref;3use std::sync::Arc;45use polars_core::prelude::*;6use polars_error::{PolarsResult, feature_gated, polars_bail};78#[cfg(feature = "serde")]9use super::serde_expr;10use crate::dsl::LazySerde;1112pub trait AnonymousColumnsUdf: ColumnsUdf {13fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf>;14fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf>;1516fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {17polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")18}1920fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field>;21}2223/// A wrapper trait for any closure `Fn(Vec<Series>) -> PolarsResult<Series>`24pub trait ColumnsUdf: Send + Sync {25fn as_any(&self) -> &dyn std::any::Any {26unimplemented!("as_any not implemented for this 'opaque' function")27}2829fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column>;30}3132impl<F> ColumnsUdf for F33where34F: Fn(&mut [Column]) -> PolarsResult<Column> + Send + Sync,35{36fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {37self(s)38}39}4041impl Debug for dyn ColumnsUdf {42fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {43write!(f, "ColumnUdf")44}45}4647#[derive(Clone)]48/// Wrapper type that has special equality properties49/// depending on the inner type specialization50pub struct SpecialEq<T>(T);5152impl<T> SpecialEq<T> {53pub fn new(val: T) -> Self {54SpecialEq(val)55}5657pub fn into_inner(self) -> T {58self.059}60}6162impl SpecialEq<Arc<dyn AnonymousColumnsUdf>> {63pub fn deep_clone(self) -> Self {64SpecialEq(self.0.deep_clone())65}66}6768impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {69fn eq(&self, other: &Self) -> bool {70Arc::ptr_eq(&self.0, &other.0)71}72}7374impl<T: ?Sized> Eq for SpecialEq<Arc<T>> {}7576impl<T: ?Sized> Hash for SpecialEq<Arc<T>> {77fn hash<H: Hasher>(&self, state: &mut H) {78Arc::as_ptr(self).hash(state);79}80}8182impl PartialEq for SpecialEq<Series> {83fn eq(&self, other: &Self) -> bool {84self.0 == other.085}86}8788impl<T> Debug for SpecialEq<T> {89fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {90write!(f, "no_eq")91}92}9394impl<T> Deref for SpecialEq<T> {95type Target = T;9697fn deref(&self) -> &Self::Target {98&self.099}100}101102pub struct BaseColumnUdf<F, DT> {103f: F,104dt: DT,105}106107impl<F, DT> BaseColumnUdf<F, DT> {108pub fn new(f: F, dt: DT) -> Self {109Self { f, dt }110}111}112113impl<F, DT> ColumnsUdf for BaseColumnUdf<F, DT>114where115F: Fn(&mut [Column]) -> PolarsResult<Column> + Send + Sync,116DT: Fn(&Schema, &[Field]) -> PolarsResult<Field> + Send + Sync,117{118fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {119(self.f)(s)120}121}122123impl<F, DT> AnonymousColumnsUdf for BaseColumnUdf<F, DT>124where125F: Fn(&mut [Column]) -> PolarsResult<Column> + 'static + Send + Sync,126DT: Fn(&Schema, &[Field]) -> PolarsResult<Field> + 'static + Send + Sync,127{128fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {129self as _130}131fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {132self133}134135fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {136(self.dt)(input_schema, fields)137}138}139140pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn AnonymousColumnsUdf>>>;141142impl Hash for OpaqueColumnUdf {143fn hash<H: Hasher>(&self, state: &mut H) {144core::mem::discriminant(self).hash(state);145match self {146Self::Deserialized(ptr) => ptr.hash(state),147Self::Bytes(b) => b.hash(state),148Self::Named {149name,150payload,151value: _,152} => {153name.hash(state);154payload.hash(state);155},156}157}158}159160pub fn new_column_udf<F: AnonymousColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {161LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))162}163164impl OpaqueColumnUdf {165pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn AnonymousColumnsUdf>>> {166match self {167Self::Deserialized(t) => Ok(t),168Self::Named {169name,170payload,171value,172} => feature_gated!("serde", {173use super::named_serde::NAMED_SERDE_REGISTRY_EXPR;174match value {175Some(v) => Ok(v),176None => Ok(SpecialEq(177NAMED_SERDE_REGISTRY_EXPR178.read()179.unwrap()180.as_ref()181.expect("NAMED EXPR REGISTRY NOT SET")182.get_function(&name, payload.unwrap().as_ref())183.expect("NAMED FUNCTION NOT FOUND"),184)),185}186}),187Self::Bytes(_b) => {188feature_gated!("serde";"python", {189serde_expr::deserialize_column_udf(_b.as_ref()).map(SpecialEq::new)190})191},192}193}194}195196197