Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/expr/expr_dyn_fn.rs
6940 views
1
use std::fmt::Formatter;
2
use std::ops::Deref;
3
use std::sync::Arc;
4
5
use super::*;
6
7
pub trait AnonymousColumnsUdf: ColumnsUdf {
8
fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf>;
9
fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf>;
10
11
fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
12
polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
13
}
14
15
fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field>;
16
}
17
18
/// A wrapper trait for any closure `Fn(Vec<Series>) -> PolarsResult<Series>`
19
pub trait ColumnsUdf: Send + Sync {
20
fn as_any(&self) -> &dyn std::any::Any {
21
unimplemented!("as_any not implemented for this 'opaque' function")
22
}
23
24
fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column>;
25
}
26
27
impl<F> ColumnsUdf for F
28
where
29
F: Fn(&mut [Column]) -> PolarsResult<Column> + Send + Sync,
30
{
31
fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {
32
self(s)
33
}
34
}
35
36
impl Debug for dyn ColumnsUdf {
37
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38
write!(f, "ColumnUdf")
39
}
40
}
41
42
/// Newtype wrapper whose equality semantics are chosen per wrapped type
/// (via specialized trait impls on `SpecialEq<...>`).
#[derive(Clone)]
pub struct SpecialEq<T>(T);
46
47
impl<T> SpecialEq<T> {
48
pub fn new(val: T) -> Self {
49
SpecialEq(val)
50
}
51
52
pub fn into_inner(self) -> T {
53
self.0
54
}
55
}
56
57
impl SpecialEq<Arc<dyn AnonymousColumnsUdf>> {
58
pub fn deep_clone(self) -> Self {
59
SpecialEq(self.0.deep_clone())
60
}
61
}
62
63
impl<T: ?Sized> PartialEq for SpecialEq<Arc<T>> {
64
fn eq(&self, other: &Self) -> bool {
65
Arc::ptr_eq(&self.0, &other.0)
66
}
67
}
68
69
impl<T: ?Sized> Eq for SpecialEq<Arc<T>> {}
70
71
impl<T: ?Sized> Hash for SpecialEq<Arc<T>> {
72
fn hash<H: Hasher>(&self, state: &mut H) {
73
Arc::as_ptr(self).hash(state);
74
}
75
}
76
77
impl PartialEq for SpecialEq<Series> {
78
fn eq(&self, other: &Self) -> bool {
79
self.0 == other.0
80
}
81
}
82
83
impl<T> Debug for SpecialEq<T> {
84
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
85
write!(f, "no_eq")
86
}
87
}
88
89
impl<T> Deref for SpecialEq<T> {
90
type Target = T;
91
92
fn deref(&self) -> &Self::Target {
93
&self.0
94
}
95
}
96
97
/// A column UDF assembled from two closures: one that computes the result column and one
/// that resolves the output [`Field`].
pub struct BaseColumnUdf<F, DT> {
    /// Applied to the input columns when the UDF is called.
    f: F,
    /// Resolves the output field from the input schema and input fields.
    dt: DT,
}
101
102
impl<F, DT> BaseColumnUdf<F, DT> {
103
pub fn new(f: F, dt: DT) -> Self {
104
Self { f, dt }
105
}
106
}
107
108
impl<F, DT> ColumnsUdf for BaseColumnUdf<F, DT>
109
where
110
F: Fn(&mut [Column]) -> PolarsResult<Column> + Send + Sync,
111
DT: Fn(&Schema, &[Field]) -> PolarsResult<Field> + Send + Sync,
112
{
113
fn call_udf(&self, s: &mut [Column]) -> PolarsResult<Column> {
114
(self.f)(s)
115
}
116
}
117
118
impl<F, DT> AnonymousColumnsUdf for BaseColumnUdf<F, DT>
119
where
120
F: Fn(&mut [Column]) -> PolarsResult<Column> + 'static + Send + Sync,
121
DT: Fn(&Schema, &[Field]) -> PolarsResult<Field> + 'static + Send + Sync,
122
{
123
fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
124
self as _
125
}
126
fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
127
self
128
}
129
130
fn get_field(&self, input_schema: &Schema, fields: &[Field]) -> PolarsResult<Field> {
131
(self.dt)(input_schema, fields)
132
}
133
}
134
135
pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn AnonymousColumnsUdf>>>;
136
pub(crate) fn new_column_udf<F: AnonymousColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
137
LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
138
}
139
140
impl OpaqueColumnUdf {
141
pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn AnonymousColumnsUdf>>> {
142
match self {
143
Self::Deserialized(t) => Ok(t),
144
Self::Named {
145
name,
146
payload,
147
value,
148
} => feature_gated!("serde", {
149
use super::named_serde::NAMED_SERDE_REGISTRY_EXPR;
150
match value {
151
Some(v) => Ok(v),
152
None => Ok(SpecialEq(
153
NAMED_SERDE_REGISTRY_EXPR
154
.read()
155
.unwrap()
156
.as_ref()
157
.expect("NAMED EXPR REGISTRY NOT SET")
158
.get_function(&name, payload.unwrap().as_ref())
159
.expect("NAMED FUNCTION NOT FOUND"),
160
)),
161
}
162
}),
163
Self::Bytes(_b) => {
164
feature_gated!("serde";"python", {
165
serde_expr::deserialize_column_udf(_b.as_ref()).map(SpecialEq::new)
166
})
167
},
168
}
169
}
170
}
171
172