Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/expr/mod.rs
6940 views
1
mod datatype_fn;
2
mod expr_dyn_fn;
3
use std::fmt::{Debug, Display, Formatter};
4
use std::hash::{Hash, Hasher};
5
6
use bytes::Bytes;
7
pub use datatype_fn::*;
8
pub use expr_dyn_fn::*;
9
use polars_compute::rolling::QuantileMethod;
10
use polars_core::chunked_array::cast::CastOptions;
11
use polars_core::error::feature_gated;
12
use polars_core::prelude::*;
13
use polars_utils::format_pl_smallstr;
14
#[cfg(feature = "serde")]
15
use serde::{Deserialize, Serialize};
16
#[cfg(feature = "serde")]
17
pub mod named_serde;
18
#[cfg(feature = "serde")]
19
mod serde_expr;
20
21
use super::datatype_expr::DataTypeExpr;
22
use crate::prelude::*;
23
24
#[derive(PartialEq, Clone, Hash)]
25
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
26
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
27
pub enum AggExpr {
28
Min {
29
input: Arc<Expr>,
30
propagate_nans: bool,
31
},
32
Max {
33
input: Arc<Expr>,
34
propagate_nans: bool,
35
},
36
Median(Arc<Expr>),
37
NUnique(Arc<Expr>),
38
First(Arc<Expr>),
39
Last(Arc<Expr>),
40
Mean(Arc<Expr>),
41
Implode(Arc<Expr>),
42
Count {
43
input: Arc<Expr>,
44
include_nulls: bool,
45
},
46
Quantile {
47
expr: Arc<Expr>,
48
quantile: Arc<Expr>,
49
method: QuantileMethod,
50
},
51
Sum(Arc<Expr>),
52
AggGroups(Arc<Expr>),
53
Std(Arc<Expr>, u8),
54
Var(Arc<Expr>, u8),
55
}
56
57
impl AsRef<Expr> for AggExpr {
58
fn as_ref(&self) -> &Expr {
59
use AggExpr::*;
60
match self {
61
Min { input, .. } => input,
62
Max { input, .. } => input,
63
Median(e) => e,
64
NUnique(e) => e,
65
First(e) => e,
66
Last(e) => e,
67
Mean(e) => e,
68
Implode(e) => e,
69
Count { input, .. } => input,
70
Quantile { expr, .. } => expr,
71
Sum(e) => e,
72
AggGroups(e) => e,
73
Std(e, _) => e,
74
Var(e, _) => e,
75
}
76
}
77
}
78
79
/// Expressions that can be used in various contexts.
80
///
81
/// Queries consist of multiple expressions.
82
/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using
83
/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info.
84
#[derive(Clone, PartialEq)]
85
#[must_use]
86
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
87
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
88
pub enum Expr {
89
Alias(Arc<Expr>, PlSmallStr),
90
Column(PlSmallStr),
91
Selector(Selector),
92
Literal(LiteralValue),
93
DataTypeFunction(DataTypeFunction),
94
BinaryExpr {
95
left: Arc<Expr>,
96
op: Operator,
97
right: Arc<Expr>,
98
},
99
Cast {
100
expr: Arc<Expr>,
101
dtype: DataTypeExpr,
102
options: CastOptions,
103
},
104
Sort {
105
expr: Arc<Expr>,
106
options: SortOptions,
107
},
108
Gather {
109
expr: Arc<Expr>,
110
idx: Arc<Expr>,
111
returns_scalar: bool,
112
},
113
SortBy {
114
expr: Arc<Expr>,
115
by: Vec<Expr>,
116
sort_options: SortMultipleOptions,
117
},
118
Agg(AggExpr),
119
/// A ternary operation
120
/// if true then "foo" else "bar"
121
Ternary {
122
predicate: Arc<Expr>,
123
truthy: Arc<Expr>,
124
falsy: Arc<Expr>,
125
},
126
Function {
127
/// function arguments
128
input: Vec<Expr>,
129
/// function to apply
130
function: FunctionExpr,
131
},
132
Explode {
133
input: Arc<Expr>,
134
skip_empty: bool,
135
},
136
Filter {
137
input: Arc<Expr>,
138
by: Arc<Expr>,
139
},
140
/// Polars flavored window functions.
141
Window {
142
/// Also has the input. i.e. avg("foo")
143
function: Arc<Expr>,
144
partition_by: Vec<Expr>,
145
order_by: Option<(Arc<Expr>, SortOptions)>,
146
options: WindowType,
147
},
148
Slice {
149
input: Arc<Expr>,
150
/// length is not yet known so we accept negative offsets
151
offset: Arc<Expr>,
152
length: Arc<Expr>,
153
},
154
/// Set root name as Alias
155
KeepName(Arc<Expr>),
156
Len,
157
#[cfg(feature = "dtype-struct")]
158
Field(Arc<[PlSmallStr]>),
159
AnonymousFunction {
160
/// function arguments
161
input: Vec<Expr>,
162
/// function to apply
163
function: OpaqueColumnUdf,
164
165
options: FunctionOptions,
166
/// used for formatting
167
#[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
168
fmt_str: Box<PlSmallStr>,
169
},
170
/// Evaluates the `evaluation` expression on the output of the `expr`.
171
///
172
/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.
173
Eval {
174
expr: Arc<Expr>,
175
evaluation: Arc<Expr>,
176
variant: EvalVariant,
177
},
178
SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
179
RenameAlias {
180
function: RenameAliasFn,
181
expr: Arc<Expr>,
182
},
183
}
184
185
#[derive(Clone)]
186
pub enum LazySerde<T: Clone> {
187
Deserialized(T),
188
Bytes(Bytes),
189
/// Named functions allow for serializing arbitrary Rust functions as long as both sides know
190
/// ahead of time which function it is. There is a registry of functions that both sides know
191
/// and every time we need serialize we serialize the function by name in the registry.
192
///
193
/// Used by cloud.
194
Named {
195
// Name and payload are used by the NamedRegistry
196
// To load the function `T` at runtime.
197
name: String,
198
payload: Option<Bytes>,
199
// Sometimes we need the function `T` before sending
200
// to a different machine, so optionally set it as well.
201
value: Option<T>,
202
},
203
}
204
205
impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
206
fn eq(&self, other: &Self) -> bool {
207
use LazySerde as L;
208
match (self, other) {
209
(L::Deserialized(a), L::Deserialized(b)) => a == b,
210
(L::Bytes(a), L::Bytes(b)) => {
211
std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
212
},
213
(
214
L::Named {
215
name: l,
216
payload: pl,
217
value: _,
218
},
219
L::Named {
220
name: r,
221
payload: pr,
222
value: _,
223
},
224
) => {
225
#[cfg(debug_assertions)]
226
{
227
if l == r {
228
assert_eq!(pl, pr, "name should point to unique payload")
229
}
230
}
231
_ = pl;
232
_ = pr;
233
l == r
234
},
235
_ => false,
236
}
237
}
238
}
239
240
impl<T: Clone> Debug for LazySerde<T> {
241
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
242
match self {
243
Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
244
Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
245
Self::Named {
246
name,
247
payload: _,
248
value: _,
249
} => write!(f, "lazy-serde<Named>: {name}"),
250
}
251
}
252
}
253
254
#[allow(clippy::derived_hash_with_manual_eq)]
255
impl Hash for Expr {
256
fn hash<H: Hasher>(&self, state: &mut H) {
257
let d = std::mem::discriminant(self);
258
d.hash(state);
259
match self {
260
Expr::Column(name) => name.hash(state),
261
// Expr::Columns(names) => names.hash(state),
262
// Expr::DtypeColumn(dtypes) => dtypes.hash(state),
263
// Expr::IndexColumn(indices) => indices.hash(state),
264
Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
265
Expr::Selector(s) => s.hash(state),
266
// Expr::Nth(v) => v.hash(state),
267
Expr::DataTypeFunction(v) => v.hash(state),
268
Expr::Filter { input, by } => {
269
input.hash(state);
270
by.hash(state);
271
},
272
Expr::BinaryExpr { left, op, right } => {
273
left.hash(state);
274
right.hash(state);
275
std::mem::discriminant(op).hash(state)
276
},
277
Expr::Cast {
278
expr,
279
dtype,
280
options: strict,
281
} => {
282
expr.hash(state);
283
dtype.hash(state);
284
strict.hash(state)
285
},
286
Expr::Sort { expr, options } => {
287
expr.hash(state);
288
options.hash(state);
289
},
290
Expr::Alias(input, name) => {
291
input.hash(state);
292
name.hash(state)
293
},
294
Expr::KeepName(input) => input.hash(state),
295
Expr::Ternary {
296
predicate,
297
truthy,
298
falsy,
299
} => {
300
predicate.hash(state);
301
truthy.hash(state);
302
falsy.hash(state);
303
},
304
Expr::Function { input, function } => {
305
input.hash(state);
306
std::mem::discriminant(function).hash(state);
307
},
308
Expr::Gather {
309
expr,
310
idx,
311
returns_scalar,
312
} => {
313
expr.hash(state);
314
idx.hash(state);
315
returns_scalar.hash(state);
316
},
317
// already hashed by discriminant
318
Expr::Len => {},
319
Expr::SortBy {
320
expr,
321
by,
322
sort_options,
323
} => {
324
expr.hash(state);
325
by.hash(state);
326
sort_options.hash(state);
327
},
328
Expr::Agg(input) => input.hash(state),
329
Expr::Explode { input, skip_empty } => {
330
skip_empty.hash(state);
331
input.hash(state)
332
},
333
Expr::Window {
334
function,
335
partition_by,
336
order_by,
337
options,
338
} => {
339
function.hash(state);
340
partition_by.hash(state);
341
order_by.hash(state);
342
options.hash(state);
343
},
344
Expr::Slice {
345
input,
346
offset,
347
length,
348
} => {
349
input.hash(state);
350
offset.hash(state);
351
length.hash(state);
352
},
353
// Expr::Exclude(input, excl) => {
354
// input.hash(state);
355
// excl.hash(state);
356
// },
357
Expr::RenameAlias { function, expr } => {
358
function.hash(state);
359
expr.hash(state);
360
},
361
Expr::AnonymousFunction {
362
input,
363
function: _,
364
options,
365
fmt_str,
366
} => {
367
input.hash(state);
368
options.hash(state);
369
fmt_str.hash(state);
370
},
371
Expr::Eval {
372
expr: input,
373
evaluation,
374
variant,
375
} => {
376
input.hash(state);
377
evaluation.hash(state);
378
variant.hash(state);
379
},
380
Expr::SubPlan(_, names) => names.hash(state),
381
#[cfg(feature = "dtype-struct")]
382
Expr::Field(names) => names.hash(state),
383
}
384
}
385
}
386
387
// Marker impl: declares the `PartialEq` derived on `Expr` to be a full equivalence relation.
impl Eq for Expr {}
388
389
impl Default for Expr {
390
fn default() -> Self {
391
Expr::Literal(LiteralValue::Scalar(Scalar::default()))
392
}
393
}
394
395
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
396
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
397
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
398
pub enum Excluded {
399
Name(PlSmallStr),
400
Dtype(DataType),
401
}
402
403
impl Expr {
404
/// Get Field result of the expression. The schema is the input data.
405
pub fn to_field(&self, schema: &Schema) -> PolarsResult<Field> {
406
// this is not called much and the expression depth is typically shallow
407
let mut arena = Arena::with_capacity(5);
408
self.to_field_amortized(schema, &mut arena)
409
}
410
pub(crate) fn to_field_amortized(
411
&self,
412
schema: &Schema,
413
expr_arena: &mut Arena<AExpr>,
414
) -> PolarsResult<Field> {
415
let mut ctx = ExprToIRContext::new(expr_arena, schema);
416
ctx.allow_unknown = true;
417
let expr = to_expr_ir(self.clone(), &mut ctx)?;
418
let (node, output_name) = expr.into_inner();
419
let dtype = expr_arena.get(node).to_dtype(schema, expr_arena)?;
420
Ok(Field::new(output_name.into_inner().unwrap(), dtype))
421
}
422
423
pub fn into_selector(self) -> Option<Selector> {
424
match self {
425
Expr::Column(name) => Some(Selector::ByName {
426
names: [name].into(),
427
strict: true,
428
}),
429
Expr::Selector(selector) => Some(selector),
430
_ => None,
431
}
432
}
433
434
pub fn try_into_selector(self) -> PolarsResult<Selector> {
435
match self {
436
Expr::Column(name) => Ok(Selector::ByName {
437
names: [name].into(),
438
strict: true,
439
}),
440
Expr::Selector(selector) => Ok(selector),
441
expr => Err(polars_err!(InvalidOperation: "cannot turn `{expr}` into selector")),
442
}
443
}
444
445
/// Extract a constant usize from an expression.
446
pub fn extract_usize(&self) -> PolarsResult<usize> {
447
match self {
448
Expr::Literal(n) => n.extract_usize(),
449
Expr::Cast { expr, dtype, .. } => {
450
// lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
451
if dtype.as_literal().is_some_and(|dt| dt.is_integer()) {
452
expr.extract_usize()
453
} else {
454
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
455
}
456
},
457
_ => {
458
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
459
},
460
}
461
}
462
463
pub fn extract_i64(&self) -> PolarsResult<i64> {
464
match self {
465
Expr::Literal(n) => n.extract_i64(),
466
Expr::BinaryExpr { left, op, right } => match op {
467
Operator::Minus => {
468
let left = left.extract_i64()?;
469
let right = right.extract_i64()?;
470
Ok(left - right)
471
},
472
_ => unreachable!(),
473
},
474
Expr::Cast { expr, dtype, .. } => {
475
if dtype.as_literal().is_some_and(|dt| dt.is_integer()) {
476
expr.extract_i64()
477
} else {
478
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
479
}
480
},
481
_ => {
482
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
483
},
484
}
485
}
486
487
#[inline]
488
pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
489
Expr::n_ary(function, vec![self])
490
}
491
#[inline]
492
pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
493
Expr::n_ary(function, vec![self, rhs])
494
}
495
496
#[inline]
497
pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
498
Expr::n_ary(function, vec![self, arg1, arg2])
499
}
500
501
#[inline]
502
pub fn try_map_n_ary(
503
self,
504
function: impl Into<FunctionExpr>,
505
exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
506
) -> PolarsResult<Expr> {
507
let exprs = exprs.into_iter();
508
let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
509
input.push(self);
510
for e in exprs {
511
input.push(e?);
512
}
513
Ok(Expr::n_ary(function, input))
514
}
515
516
#[inline]
517
pub fn map_n_ary(
518
self,
519
function: impl Into<FunctionExpr>,
520
exprs: impl IntoIterator<Item = Expr>,
521
) -> Expr {
522
let exprs = exprs.into_iter();
523
let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
524
input.push(self);
525
input.extend(exprs);
526
Expr::n_ary(function, input)
527
}
528
529
#[inline]
530
pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
531
let function = function.into();
532
Expr::Function { input, function }
533
}
534
}
535
536
// Which flavour of `Expr::Eval` is meant; `to_name` gives the user-facing name of each.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum EvalVariant {
    /// `list.eval`
    List,

    /// `cumulative_eval`
    Cumulative { min_samples: usize },
}
546
547
impl EvalVariant {
548
pub fn to_name(&self) -> &'static str {
549
match self {
550
Self::List => "list.eval",
551
Self::Cumulative { min_samples: _ } => "cumulative_eval",
552
}
553
}
554
555
/// Get the `DataType` of the `pl.element()` value.
556
pub fn element_dtype<'a>(&self, dtype: &'a DataType) -> PolarsResult<&'a DataType> {
557
match (self, dtype) {
558
(Self::List, DataType::List(inner)) => Ok(inner.as_ref()),
559
(Self::Cumulative { min_samples: _ }, dt) => Ok(dt),
560
_ => polars_bail!(op = self.to_name(), dtype),
561
}
562
}
563
564
pub fn is_elementwise(&self) -> bool {
565
match self {
566
EvalVariant::List => true,
567
EvalVariant::Cumulative { min_samples: _ } => false,
568
}
569
}
570
571
pub fn is_row_separable(&self) -> bool {
572
match self {
573
EvalVariant::List => true,
574
EvalVariant::Cumulative { min_samples: _ } => false,
575
}
576
}
577
578
pub fn is_length_preserving(&self) -> bool {
579
match self {
580
EvalVariant::List | EvalVariant::Cumulative { .. } => true,
581
}
582
}
583
}
584
585
// Operators usable in `Expr::BinaryExpr`.
//
// Variant order is left untouched. Comments are plain `//` so the derived JSON schema is not
// affected.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum Operator {
    // Comparisons (`is_comparison` returns true for these eight).
    Eq,
    EqValidity,
    NotEq,
    NotEqValidity,
    Lt,
    LtEq,
    Gt,
    GtEq,
    // Everything from here to `Modulus`, plus the two `Logical*` variants, counts as
    // arithmetic according to `is_arithmetic`.
    Plus,
    Minus,
    Multiply,
    Divide,
    TrueDivide,
    FloorDivide,
    Modulus,
    // Bitwise (`is_bitwise` returns true for these three).
    And,
    Or,
    Xor,
    // Displayed with the same tokens as `And` / `Or`, but not reported as bitwise.
    LogicalAnd,
    LogicalOr,
}
610
611
impl Display for Operator {
612
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
613
use Operator::*;
614
let tkn = match self {
615
Eq => "==",
616
EqValidity => "==v",
617
NotEq => "!=",
618
NotEqValidity => "!=v",
619
Lt => "<",
620
LtEq => "<=",
621
Gt => ">",
622
GtEq => ">=",
623
Plus => "+",
624
Minus => "-",
625
Multiply => "*",
626
Divide => "//",
627
TrueDivide => "/",
628
FloorDivide => "floor_div",
629
Modulus => "%",
630
And | LogicalAnd => "&",
631
Or | LogicalOr => "|",
632
Xor => "^",
633
};
634
write!(f, "{tkn}")
635
}
636
}
637
638
impl Operator {
639
pub fn is_comparison(&self) -> bool {
640
matches!(
641
self,
642
Self::Eq
643
| Self::NotEq
644
| Self::Lt
645
| Self::LtEq
646
| Self::Gt
647
| Self::GtEq
648
| Self::EqValidity
649
| Self::NotEqValidity
650
)
651
}
652
653
pub fn is_bitwise(&self) -> bool {
654
matches!(self, Self::And | Self::Or | Self::Xor)
655
}
656
657
pub fn is_comparison_or_bitwise(&self) -> bool {
658
self.is_comparison() || self.is_bitwise()
659
}
660
661
pub fn swap_operands(self) -> Self {
662
match self {
663
Operator::Eq => Operator::Eq,
664
Operator::Gt => Operator::Lt,
665
Operator::GtEq => Operator::LtEq,
666
Operator::LtEq => Operator::GtEq,
667
Operator::Or => Operator::Or,
668
Operator::LogicalAnd => Operator::LogicalAnd,
669
Operator::LogicalOr => Operator::LogicalOr,
670
Operator::Xor => Operator::Xor,
671
Operator::NotEq => Operator::NotEq,
672
Operator::EqValidity => Operator::EqValidity,
673
Operator::NotEqValidity => Operator::NotEqValidity,
674
Operator::Divide => Operator::Multiply,
675
Operator::Multiply => Operator::Divide,
676
Operator::And => Operator::And,
677
Operator::Plus => Operator::Minus,
678
Operator::Minus => Operator::Plus,
679
Operator::Lt => Operator::Gt,
680
_ => unimplemented!(),
681
}
682
}
683
684
pub fn is_arithmetic(&self) -> bool {
685
!(self.is_comparison_or_bitwise())
686
}
687
}
688
689
#[derive(Clone, PartialEq, Eq, Hash)]
690
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
691
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
692
pub enum RenameAliasFn {
693
Prefix(PlSmallStr),
694
Suffix(PlSmallStr),
695
ToLowercase,
696
ToUppercase,
697
#[cfg(feature = "python")]
698
Python(SpecialEq<Arc<polars_utils::python_function::PythonObject>>),
699
#[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
700
Rust(SpecialEq<Arc<RenameAliasRustFn>>),
701
}
702
703
impl RenameAliasFn {
704
pub fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
705
let out = match self {
706
Self::Prefix(prefix) => format_pl_smallstr!("{prefix}{name}"),
707
Self::Suffix(suffix) => format_pl_smallstr!("{name}{suffix}"),
708
Self::ToLowercase => PlSmallStr::from_string(name.to_lowercase()),
709
Self::ToUppercase => PlSmallStr::from_string(name.to_uppercase()),
710
#[cfg(feature = "python")]
711
Self::Python(lambda) => {
712
let name = name.as_str();
713
pyo3::marker::Python::with_gil(|py| {
714
let out: PlSmallStr = lambda
715
.call1(py, (name,))?
716
.extract::<std::borrow::Cow<str>>(py)?
717
.as_ref()
718
.into();
719
pyo3::PyResult::<_>::Ok(out)
720
}).map_err(|e| polars_err!(ComputeError: "Python function in 'name.map' produced an error: {e}."))?
721
},
722
Self::Rust(f) => f(name)?,
723
};
724
Ok(out)
725
}
726
}
727
728
pub type RenameAliasRustFn =
729
dyn Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + 'static + Send + Sync;
730
731