Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
// GitHub repository: pola-rs/polars
// Path: blob/main/crates/polars-plan/src/dsl/expr/mod.rs
8327 views
1
pub mod anonymous;
2
mod datatype_fn;
3
use std::fmt::{Debug, Display, Formatter};
4
use std::hash::{Hash, Hasher};
5
6
pub use anonymous::*;
7
use bytes::Bytes;
8
pub use datatype_fn::*;
9
use polars_compute::rolling::QuantileMethod;
10
use polars_core::chunked_array::cast::CastOptions;
11
use polars_core::error::feature_gated;
12
use polars_core::prelude::*;
13
use polars_utils::format_pl_smallstr;
14
#[cfg(feature = "serde")]
15
use serde::{Deserialize, Serialize};
16
17
use super::datatype_expr::DataTypeExpr;
18
use crate::prelude::*;
19
20
#[derive(PartialEq, Clone, Hash)]
21
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
23
pub enum AggExpr {
24
Min {
25
input: Arc<Expr>,
26
propagate_nans: bool,
27
},
28
Max {
29
input: Arc<Expr>,
30
propagate_nans: bool,
31
},
32
Median(Arc<Expr>),
33
NUnique(Arc<Expr>),
34
First(Arc<Expr>),
35
FirstNonNull(Arc<Expr>),
36
Last(Arc<Expr>),
37
LastNonNull(Arc<Expr>),
38
Item {
39
input: Arc<Expr>,
40
/// Give a missing value if there are no values.
41
allow_empty: bool,
42
},
43
Mean(Arc<Expr>),
44
Implode(Arc<Expr>),
45
Count {
46
input: Arc<Expr>,
47
include_nulls: bool,
48
},
49
Quantile {
50
expr: Arc<Expr>,
51
quantile: Arc<Expr>,
52
method: QuantileMethod,
53
},
54
Sum(Arc<Expr>),
55
AggGroups(Arc<Expr>),
56
Std(Arc<Expr>, u8),
57
Var(Arc<Expr>, u8),
58
}
59
60
impl AsRef<Expr> for AggExpr {
61
fn as_ref(&self) -> &Expr {
62
use AggExpr::*;
63
match self {
64
Min { input, .. } => input,
65
Max { input, .. } => input,
66
Median(e) => e,
67
NUnique(e) => e,
68
First(e) => e,
69
FirstNonNull(e) => e,
70
Last(e) => e,
71
LastNonNull(e) => e,
72
Item { input, .. } => input,
73
Mean(e) => e,
74
Implode(e) => e,
75
Count { input, .. } => input,
76
Quantile { expr, .. } => expr,
77
Sum(e) => e,
78
AggGroups(e) => e,
79
Std(e, _) => e,
80
Var(e, _) => e,
81
}
82
}
83
}
84
85
/// Expressions that can be used in various contexts.
86
///
87
/// Queries consist of multiple expressions.
88
/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using
89
/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info.
90
#[derive(Clone, PartialEq)]
91
#[must_use]
92
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
93
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
94
pub enum Expr {
95
/// Values in a `eval` context.
96
///
97
/// Equivalent of `pl.element()`.
98
Element,
99
Alias(Arc<Expr>, PlSmallStr),
100
Column(PlSmallStr),
101
Selector(Selector),
102
Literal(LiteralValue),
103
DataTypeFunction(DataTypeFunction),
104
BinaryExpr {
105
left: Arc<Expr>,
106
op: Operator,
107
right: Arc<Expr>,
108
},
109
Cast {
110
expr: Arc<Expr>,
111
dtype: DataTypeExpr,
112
options: CastOptions,
113
},
114
Sort {
115
expr: Arc<Expr>,
116
options: SortOptions,
117
},
118
Gather {
119
expr: Arc<Expr>,
120
idx: Arc<Expr>,
121
returns_scalar: bool,
122
null_on_oob: bool,
123
},
124
SortBy {
125
expr: Arc<Expr>,
126
by: Vec<Expr>,
127
sort_options: SortMultipleOptions,
128
},
129
Agg(AggExpr),
130
/// A ternary operation
131
/// if true then "foo" else "bar"
132
Ternary {
133
predicate: Arc<Expr>,
134
truthy: Arc<Expr>,
135
falsy: Arc<Expr>,
136
},
137
Function {
138
/// function arguments
139
input: Vec<Expr>,
140
/// function to apply
141
function: FunctionExpr,
142
},
143
Explode {
144
input: Arc<Expr>,
145
options: ExplodeOptions,
146
},
147
Filter {
148
input: Arc<Expr>,
149
by: Arc<Expr>,
150
},
151
/// Polars flavored window functions.
152
Over {
153
/// Also has the input. i.e. avg("foo")
154
function: Arc<Expr>,
155
partition_by: Vec<Expr>,
156
order_by: Option<(Arc<Expr>, SortOptions)>,
157
mapping: WindowMapping,
158
},
159
#[cfg(feature = "dynamic_group_by")]
160
Rolling {
161
function: Arc<Expr>,
162
index_column: Arc<Expr>,
163
period: Duration,
164
offset: Duration,
165
closed_window: ClosedWindow,
166
},
167
Slice {
168
input: Arc<Expr>,
169
/// length is not yet known so we accept negative offsets
170
offset: Arc<Expr>,
171
length: Arc<Expr>,
172
},
173
/// Set root name as Alias
174
KeepName(Arc<Expr>),
175
Len,
176
#[cfg(feature = "dtype-struct")]
177
Field(Arc<[PlSmallStr]>),
178
AnonymousFunction {
179
/// function arguments
180
input: Vec<Expr>,
181
/// function to apply
182
function: OpaqueColumnUdf,
183
184
options: FunctionOptions,
185
/// used for formatting
186
#[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
187
fmt_str: Box<PlSmallStr>,
188
},
189
/// Evaluates the `evaluation` expression on the output of the `expr`.
190
///
191
/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.
192
Eval {
193
expr: Arc<Expr>,
194
evaluation: Arc<Expr>,
195
variant: EvalVariant,
196
},
197
/// Evaluates the `evaluation` expressions on the output of the `expr`.
198
///
199
/// Consequently, `expr` is an input and `evaluation` uses an extended schema that includes this input.
200
#[cfg(feature = "dtype-struct")]
201
StructEval {
202
expr: Arc<Expr>,
203
evaluation: Vec<Expr>,
204
},
205
/// SQL SubQueries
206
SubPlan(SpecialEq<Arc<DslPlan>>, Vec<PlSmallStr>),
207
RenameAlias {
208
function: RenameAliasFn,
209
expr: Arc<Expr>,
210
},
211
/// Not a real expression. This is meant
212
/// as catch-all for IR expressions that
213
/// are not supported by DSL.
214
Display {
215
inputs: Vec<Expr>,
216
fmt_str: Box<PlSmallStr>,
217
},
218
}
219
220
#[derive(Clone)]
221
pub enum LazySerde<T: Clone> {
222
Deserialized(T),
223
Bytes(Bytes),
224
/// Named functions allow for serializing arbitrary Rust functions as long as both sides know
225
/// ahead of time which function it is. There is a registry of functions that both sides know
226
/// and every time we need serialize we serialize the function by name in the registry.
227
///
228
/// Used by cloud.
229
Named {
230
// Name and payload are used by the NamedRegistry
231
// To load the function `T` at runtime.
232
name: String,
233
payload: Option<Bytes>,
234
// Sometimes we need the function `T` before sending
235
// to a different machine, so optionally set it as well.
236
value: Option<T>,
237
},
238
}
239
240
impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
241
fn eq(&self, other: &Self) -> bool {
242
use LazySerde as L;
243
match (self, other) {
244
(L::Deserialized(a), L::Deserialized(b)) => a == b,
245
(L::Bytes(a), L::Bytes(b)) => {
246
std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
247
},
248
(
249
L::Named {
250
name: l,
251
payload: pl,
252
value: _,
253
},
254
L::Named {
255
name: r,
256
payload: pr,
257
value: _,
258
},
259
) => {
260
#[cfg(debug_assertions)]
261
{
262
if l == r {
263
assert_eq!(pl, pr, "name should point to unique payload")
264
}
265
}
266
_ = pl;
267
_ = pr;
268
l == r
269
},
270
_ => false,
271
}
272
}
273
}
274
275
impl<T: Clone> Debug for LazySerde<T> {
276
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
277
match self {
278
Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
279
Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
280
Self::Named {
281
name,
282
payload: _,
283
value: _,
284
} => write!(f, "lazy-serde<Named>: {name}"),
285
}
286
}
287
}
288
289
#[allow(clippy::derived_hash_with_manual_eq)]
290
impl Hash for Expr {
291
fn hash<H: Hasher>(&self, state: &mut H) {
292
let d = std::mem::discriminant(self);
293
d.hash(state);
294
match self {
295
Expr::Column(name) => name.hash(state),
296
// Expr::Columns(names) => names.hash(state),
297
// Expr::DtypeColumn(dtypes) => dtypes.hash(state),
298
// Expr::IndexColumn(indices) => indices.hash(state),
299
Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
300
Expr::Selector(s) => s.hash(state),
301
// Expr::Nth(v) => v.hash(state),
302
Expr::DataTypeFunction(v) => v.hash(state),
303
Expr::Filter { input, by } => {
304
input.hash(state);
305
by.hash(state);
306
},
307
Expr::BinaryExpr { left, op, right } => {
308
left.hash(state);
309
right.hash(state);
310
std::mem::discriminant(op).hash(state)
311
},
312
Expr::Cast {
313
expr,
314
dtype,
315
options: strict,
316
} => {
317
expr.hash(state);
318
dtype.hash(state);
319
strict.hash(state)
320
},
321
Expr::Sort { expr, options } => {
322
expr.hash(state);
323
options.hash(state);
324
},
325
Expr::Alias(input, name) => {
326
input.hash(state);
327
name.hash(state)
328
},
329
Expr::KeepName(input) => input.hash(state),
330
Expr::Ternary {
331
predicate,
332
truthy,
333
falsy,
334
} => {
335
predicate.hash(state);
336
truthy.hash(state);
337
falsy.hash(state);
338
},
339
Expr::Function { input, function } => {
340
input.hash(state);
341
std::mem::discriminant(function).hash(state);
342
},
343
Expr::Gather {
344
expr,
345
idx,
346
returns_scalar,
347
null_on_oob,
348
} => {
349
expr.hash(state);
350
idx.hash(state);
351
returns_scalar.hash(state);
352
null_on_oob.hash(state);
353
},
354
// already hashed by discriminant
355
Expr::Element | Expr::Len => {},
356
Expr::SortBy {
357
expr,
358
by,
359
sort_options,
360
} => {
361
expr.hash(state);
362
by.hash(state);
363
sort_options.hash(state);
364
},
365
Expr::Agg(input) => input.hash(state),
366
Expr::Explode { input, options } => {
367
options.hash(state);
368
input.hash(state)
369
},
370
#[cfg(feature = "dynamic_group_by")]
371
Expr::Rolling {
372
function,
373
index_column,
374
period,
375
offset,
376
closed_window,
377
} => {
378
function.hash(state);
379
index_column.hash(state);
380
period.hash(state);
381
offset.hash(state);
382
closed_window.hash(state);
383
},
384
Expr::Over {
385
function,
386
partition_by,
387
order_by,
388
mapping,
389
} => {
390
function.hash(state);
391
partition_by.hash(state);
392
order_by.hash(state);
393
mapping.hash(state);
394
},
395
Expr::Slice {
396
input,
397
offset,
398
length,
399
} => {
400
input.hash(state);
401
offset.hash(state);
402
length.hash(state);
403
},
404
Expr::RenameAlias { function, expr } => {
405
function.hash(state);
406
expr.hash(state);
407
},
408
Expr::Display { inputs, fmt_str } => {
409
inputs.hash(state);
410
fmt_str.hash(state);
411
},
412
Expr::AnonymousFunction {
413
input,
414
function: _,
415
options,
416
fmt_str,
417
} => {
418
input.hash(state);
419
options.hash(state);
420
fmt_str.hash(state);
421
},
422
Expr::Eval {
423
expr: input,
424
evaluation,
425
variant,
426
} => {
427
input.hash(state);
428
evaluation.hash(state);
429
variant.hash(state);
430
},
431
#[cfg(feature = "dtype-struct")]
432
Expr::StructEval {
433
expr: input,
434
evaluation,
435
} => {
436
input.hash(state);
437
evaluation.hash(state);
438
},
439
Expr::SubPlan(_, names) => names.hash(state),
440
#[cfg(feature = "dtype-struct")]
441
Expr::Field(names) => names.hash(state),
442
}
443
}
444
}
445
446
// Marker impl: declares that `Expr`'s `PartialEq` is a full equivalence relation.
impl Eq for Expr {}
447
448
impl Default for Expr {
449
fn default() -> Self {
450
Expr::Literal(LiteralValue::Scalar(Scalar::default()))
451
}
452
}
453
454
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
455
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
456
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
457
pub enum Excluded {
458
Name(PlSmallStr),
459
Dtype(DataType),
460
}
461
462
impl Expr {
463
/// Get Field result of the expression. The schema is the input data.
464
pub fn to_field(&self, schema: &Schema) -> PolarsResult<Field> {
465
// this is not called much and the expression depth is typically shallow
466
let mut arena = Arena::with_capacity(5);
467
self.to_field_amortized(schema, &mut arena)
468
}
469
pub(crate) fn to_field_amortized(
470
&self,
471
schema: &Schema,
472
expr_arena: &mut Arena<AExpr>,
473
) -> PolarsResult<Field> {
474
let mut ctx = ExprToIRContext::new_with_fields(expr_arena, schema);
475
ctx.allow_unknown = true;
476
let expr = to_expr_ir(self.clone(), &mut ctx)?;
477
let (node, output_name) = expr.into_inner();
478
let dtype = expr_arena
479
.get(node)
480
.to_dtype(&ToFieldContext::new(expr_arena, schema))?;
481
Ok(Field::new(output_name.into_inner().unwrap(), dtype))
482
}
483
484
pub fn into_selector(self) -> Option<Selector> {
485
match self {
486
Expr::Column(name) => Some(Selector::ByName {
487
names: [name].into(),
488
strict: true,
489
}),
490
Expr::Selector(selector) => Some(selector),
491
_ => None,
492
}
493
}
494
495
pub fn try_into_selector(self) -> PolarsResult<Selector> {
496
match self {
497
Expr::Column(name) => Ok(Selector::ByName {
498
names: [name].into(),
499
strict: true,
500
}),
501
Expr::Selector(selector) => Ok(selector),
502
expr => Err(polars_err!(InvalidOperation: "cannot turn `{expr}` into selector")),
503
}
504
}
505
506
/// Extract a constant usize from an expression.
507
pub fn extract_usize(&self) -> PolarsResult<usize> {
508
match self {
509
Expr::Literal(n) => n.extract_usize(),
510
Expr::Cast { expr, dtype, .. } => {
511
// lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
512
if dtype.as_literal().is_some_and(|dt| dt.is_integer()) {
513
expr.extract_usize()
514
} else {
515
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
516
}
517
},
518
_ => {
519
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
520
},
521
}
522
}
523
524
pub fn extract_i64(&self) -> PolarsResult<i64> {
525
match self {
526
Expr::Literal(n) => n.extract_i64(),
527
Expr::BinaryExpr { left, op, right } => match op {
528
Operator::Minus => {
529
let left = left.extract_i64()?;
530
let right = right.extract_i64()?;
531
Ok(left - right)
532
},
533
_ => unreachable!(),
534
},
535
Expr::Cast { expr, dtype, .. } => {
536
if dtype.as_literal().is_some_and(|dt| dt.is_integer()) {
537
expr.extract_i64()
538
} else {
539
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
540
}
541
},
542
_ => {
543
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
544
},
545
}
546
}
547
548
#[inline]
549
pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
550
Expr::n_ary(function, vec![self])
551
}
552
#[inline]
553
pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
554
Expr::n_ary(function, vec![self, rhs])
555
}
556
557
#[inline]
558
pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
559
Expr::n_ary(function, vec![self, arg1, arg2])
560
}
561
562
#[inline]
563
pub fn try_map_n_ary(
564
self,
565
function: impl Into<FunctionExpr>,
566
exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
567
) -> PolarsResult<Expr> {
568
let exprs = exprs.into_iter();
569
let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
570
input.push(self);
571
for e in exprs {
572
input.push(e?);
573
}
574
Ok(Expr::n_ary(function, input))
575
}
576
577
#[inline]
578
pub fn map_n_ary(
579
self,
580
function: impl Into<FunctionExpr>,
581
exprs: impl IntoIterator<Item = Expr>,
582
) -> Expr {
583
let exprs = exprs.into_iter();
584
let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
585
input.push(self);
586
input.extend(exprs);
587
Expr::n_ary(function, input)
588
}
589
590
#[inline]
591
pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
592
let function = function.into();
593
Expr::Function { input, function }
594
}
595
}
596
597
/// Which flavor of evaluation an `Expr::Eval` performs.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum EvalVariant {
    /// `list.eval`
    List,
    /// `list.agg`
    ListAgg,

    /// `array.eval`
    Array {
        /// If set to true, evaluation can output variable amount of items and output datatype will
        /// be `List`.
        as_list: bool,
    },
    /// `arr.agg`
    ArrayAgg,

    /// `cumulative_eval`
    Cumulative { min_samples: usize },
}
618
619
impl EvalVariant {
620
pub fn to_name(&self) -> &'static str {
621
match self {
622
Self::List => "list.eval",
623
Self::ListAgg => "list.agg",
624
Self::Array { .. } => "array.eval",
625
Self::ArrayAgg => "array.agg",
626
Self::Cumulative { min_samples: _ } => "cumulative_eval",
627
}
628
}
629
630
/// Get the `DataType` of the `pl.element()` value.
631
pub fn element_dtype<'a>(&self, dtype: &'a DataType) -> PolarsResult<&'a DataType> {
632
match (self, dtype) {
633
(Self::List | Self::ListAgg, DataType::List(inner)) => Ok(inner.as_ref()),
634
#[cfg(feature = "dtype-array")]
635
(Self::Array { .. } | Self::ArrayAgg, DataType::Array(inner, _)) => Ok(inner.as_ref()),
636
(Self::Cumulative { min_samples: _ }, dt) => Ok(dt),
637
_ => polars_bail!(op = self.to_name(), dtype),
638
}
639
}
640
641
/// Get the output datatype from the output element datatype
642
pub fn output_dtype(
643
&self,
644
dtype: &'_ DataType,
645
output_element_dtype: DataType,
646
eval_is_scalar: bool,
647
) -> PolarsResult<DataType> {
648
match (self, dtype) {
649
(Self::List, DataType::List(_)) => Ok(DataType::List(Box::new(output_element_dtype))),
650
(Self::ListAgg, DataType::List(_)) => {
651
if eval_is_scalar {
652
Ok(output_element_dtype)
653
} else {
654
Ok(DataType::List(Box::new(output_element_dtype)))
655
}
656
},
657
#[cfg(feature = "dtype-array")]
658
(Self::Array { as_list: false }, DataType::Array(_, width)) => {
659
Ok(DataType::Array(Box::new(output_element_dtype), *width))
660
},
661
#[cfg(feature = "dtype-array")]
662
(Self::Array { as_list: true }, DataType::Array(_, _)) => {
663
Ok(DataType::List(Box::new(output_element_dtype)))
664
},
665
#[cfg(feature = "dtype-array")]
666
(Self::ArrayAgg, DataType::Array(_, _)) => {
667
if eval_is_scalar {
668
Ok(output_element_dtype)
669
} else {
670
Ok(DataType::List(Box::new(output_element_dtype)))
671
}
672
},
673
(Self::Cumulative { min_samples: _ }, _) => Ok(output_element_dtype),
674
_ => polars_bail!(op = self.to_name(), dtype),
675
}
676
}
677
678
pub fn is_elementwise(&self) -> bool {
679
match self {
680
EvalVariant::List | EvalVariant::ListAgg => true,
681
EvalVariant::Array { .. } | EvalVariant::ArrayAgg => true,
682
EvalVariant::Cumulative { min_samples: _ } => false,
683
}
684
}
685
686
pub fn is_row_separable(&self) -> bool {
687
match self {
688
EvalVariant::List | EvalVariant::ListAgg => true,
689
EvalVariant::Array { .. } | EvalVariant::ArrayAgg => true,
690
EvalVariant::Cumulative { min_samples: _ } => false,
691
}
692
}
693
694
pub fn is_length_preserving(&self) -> bool {
695
match self {
696
EvalVariant::List
697
| EvalVariant::ListAgg
698
| EvalVariant::Array { .. }
699
| EvalVariant::ArrayAgg
700
| EvalVariant::Cumulative { .. } => true,
701
}
702
}
703
}
704
705
/// Binary operators usable in `Expr::BinaryExpr`.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum Operator {
    Eq,
    EqValidity,
    NotEq,
    NotEqValidity,
    Lt,
    LtEq,
    Gt,
    GtEq,
    Plus,
    Minus,
    Multiply,
    /// Rust division semantics, this is what Rust interface `/` dispatches to
    RustDivide,
    /// Python division semantics, converting to floats. This is what python `/` operator dispatches to
    TrueDivide,
    /// Floor division semantics, this is what python `//` dispatches to
    FloorDivide,
    Modulus,
    And,
    Or,
    Xor,
    LogicalAnd,
    LogicalOr,
}

/// Renders the operator as a short token. `And`/`LogicalAnd` both print `&` and
/// `Or`/`LogicalOr` both print `|`, so the output does not distinguish them.
impl Display for Operator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use Operator::*;
        let tkn = match self {
            Eq => "==",
            EqValidity => "==v",
            NotEq => "!=",
            NotEqValidity => "!=v",
            Lt => "<",
            LtEq => "<=",
            Gt => ">",
            GtEq => ">=",
            Plus => "+",
            Minus => "-",
            Multiply => "*",
            RustDivide => "rust_div",
            TrueDivide => "/",
            FloorDivide => "//",
            Modulus => "%",
            And | LogicalAnd => "&",
            Or | LogicalOr => "|",
            Xor => "^",
        };
        write!(f, "{tkn}")
    }
}

impl Operator {
    /// `true` for the equality and ordering operators, including the `*Validity` forms.
    pub fn is_comparison(&self) -> bool {
        matches!(
            self,
            Self::Eq
                | Self::NotEq
                | Self::Lt
                | Self::LtEq
                | Self::Gt
                | Self::GtEq
                | Self::EqValidity
                | Self::NotEqValidity
        )
    }

    /// `true` for `And`, `Or` and `Xor` only; `LogicalAnd`/`LogicalOr` are not included.
    pub fn is_bitwise(&self) -> bool {
        matches!(self, Self::And | Self::Or | Self::Xor)
    }

    /// `true` when either [`Operator::is_comparison`] or [`Operator::is_bitwise`] holds.
    pub fn is_comparison_or_bitwise(&self) -> bool {
        self.is_comparison() || self.is_bitwise()
    }

    /// Operator to use after swapping the left and right operands so the result is unchanged:
    /// symmetric operators map to themselves and `Lt`/`Gt`, `LtEq`/`GtEq` are mirrored.
    ///
    /// # Panics
    /// Panics (`unimplemented!`) for `RustDivide`, `Minus`, and every operator without an
    /// explicit arm (`TrueDivide`, `FloorDivide`, `Modulus`).
    pub fn swap_operands(self) -> Self {
        match self {
            Operator::Eq => Operator::Eq,
            Operator::Gt => Operator::Lt,
            Operator::GtEq => Operator::LtEq,
            Operator::LtEq => Operator::GtEq,
            Operator::Or => Operator::Or,
            Operator::LogicalAnd => Operator::LogicalAnd,
            Operator::LogicalOr => Operator::LogicalOr,
            Operator::Xor => Operator::Xor,
            Operator::NotEq => Operator::NotEq,
            Operator::EqValidity => Operator::EqValidity,
            Operator::NotEqValidity => Operator::NotEqValidity,
            // Operator::Divide requires modifying the right operand: left / right == 1/right * left
            Operator::RustDivide => unimplemented!(),
            Operator::Multiply => Operator::Multiply,
            Operator::And => Operator::And,
            Operator::Plus => Operator::Plus,
            // Operator::Minus requires modifying the right operand: left - right == -right + left
            Operator::Minus => unimplemented!(),
            Operator::Lt => Operator::Gt,
            _ => unimplemented!(),
        }
    }

    /// Negation of [`Operator::is_comparison_or_bitwise`].
    ///
    /// NOTE(review): as written this also returns `true` for `LogicalAnd`/`LogicalOr`, since
    /// they are neither comparison nor bitwise here — confirm callers expect that.
    pub fn is_arithmetic(&self) -> bool {
        !(self.is_comparison_or_bitwise())
    }
}
813
814
#[derive(Clone, PartialEq, Eq, Hash)]
815
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
816
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
817
pub enum RenameAliasFn {
818
Prefix(PlSmallStr),
819
Suffix(PlSmallStr),
820
ToLowercase,
821
ToUppercase,
822
Map(PlanCallback<PlSmallStr, PlSmallStr>),
823
Replace {
824
pattern: PlSmallStr,
825
value: PlSmallStr,
826
literal: bool,
827
},
828
}
829
830
impl RenameAliasFn {
831
pub fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
832
let out = match self {
833
Self::Prefix(prefix) => format_pl_smallstr!("{prefix}{name}"),
834
Self::Suffix(suffix) => format_pl_smallstr!("{name}{suffix}"),
835
Self::ToLowercase => PlSmallStr::from_string(name.to_lowercase()),
836
Self::ToUppercase => PlSmallStr::from_string(name.to_uppercase()),
837
Self::Map(f) => f.call(name.clone())?,
838
Self::Replace {
839
pattern,
840
value,
841
literal,
842
} => {
843
if *literal {
844
name.replace(pattern.as_str(), value.as_str()).into()
845
} else {
846
feature_gated!("regex", {
847
let rx = polars_utils::regex_cache::compile_regex(pattern)?;
848
rx.replace_all(name, value.as_str()).into()
849
})
850
}
851
},
852
};
853
Ok(out)
854
}
855
}
856
857
pub type RenameAliasRustFn =
858
dyn Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + 'static + Send + Sync;
859
860