Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/expr/mod.rs
8424 views
1
pub mod anonymous;
2
mod datatype_fn;
3
use std::fmt::{Debug, Display, Formatter};
4
use std::hash::{Hash, Hasher};
5
6
pub use anonymous::*;
7
use bytes::Bytes;
8
pub use datatype_fn::*;
9
use polars_compute::rolling::QuantileMethod;
10
use polars_core::chunked_array::cast::CastOptions;
11
use polars_core::error::feature_gated;
12
use polars_core::prelude::*;
13
use polars_utils::format_pl_smallstr;
14
#[cfg(feature = "serde")]
15
use serde::{Deserialize, Serialize};
16
17
use super::datatype_expr::DataTypeExpr;
18
use crate::prelude::*;
19
20
#[derive(PartialEq, Clone, Hash)]
21
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
23
pub enum AggExpr {
24
Min {
25
input: Arc<Expr>,
26
propagate_nans: bool,
27
},
28
Max {
29
input: Arc<Expr>,
30
propagate_nans: bool,
31
},
32
Median(Arc<Expr>),
33
NUnique(Arc<Expr>),
34
First(Arc<Expr>),
35
FirstNonNull(Arc<Expr>),
36
Last(Arc<Expr>),
37
LastNonNull(Arc<Expr>),
38
Item {
39
input: Arc<Expr>,
40
/// Give a missing value if there are no values.
41
allow_empty: bool,
42
},
43
Mean(Arc<Expr>),
44
Implode(Arc<Expr>),
45
Count {
46
input: Arc<Expr>,
47
include_nulls: bool,
48
},
49
Quantile {
50
expr: Arc<Expr>,
51
quantile: Arc<Expr>,
52
method: QuantileMethod,
53
},
54
Sum(Arc<Expr>),
55
AggGroups(Arc<Expr>),
56
Std(Arc<Expr>, u8),
57
Var(Arc<Expr>, u8),
58
}
59
60
impl AsRef<Expr> for AggExpr {
61
fn as_ref(&self) -> &Expr {
62
use AggExpr::*;
63
match self {
64
Min { input, .. } => input,
65
Max { input, .. } => input,
66
Median(e) => e,
67
NUnique(e) => e,
68
First(e) => e,
69
FirstNonNull(e) => e,
70
Last(e) => e,
71
LastNonNull(e) => e,
72
Item { input, .. } => input,
73
Mean(e) => e,
74
Implode(e) => e,
75
Count { input, .. } => input,
76
Quantile { expr, .. } => expr,
77
Sum(e) => e,
78
AggGroups(e) => e,
79
Std(e, _) => e,
80
Var(e, _) => e,
81
}
82
}
83
}
84
85
/// Expressions that can be used in various contexts.
86
///
87
/// Queries consist of multiple expressions.
88
/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using
89
/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info.
90
#[derive(Clone, PartialEq)]
91
#[must_use]
92
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
93
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
94
pub enum Expr {
95
/// Values in a `eval` context.
96
///
97
/// Equivalent of `pl.element()`.
98
Element,
99
Alias(Arc<Expr>, PlSmallStr),
100
Column(PlSmallStr),
101
Selector(Selector),
102
Literal(LiteralValue),
103
DataTypeFunction(DataTypeFunction),
104
BinaryExpr {
105
left: Arc<Expr>,
106
op: Operator,
107
right: Arc<Expr>,
108
},
109
Cast {
110
expr: Arc<Expr>,
111
dtype: DataTypeExpr,
112
options: CastOptions,
113
},
114
Sort {
115
expr: Arc<Expr>,
116
options: SortOptions,
117
},
118
Gather {
119
expr: Arc<Expr>,
120
idx: Arc<Expr>,
121
returns_scalar: bool,
122
null_on_oob: bool,
123
},
124
SortBy {
125
expr: Arc<Expr>,
126
by: Vec<Expr>,
127
sort_options: SortMultipleOptions,
128
},
129
Agg(AggExpr),
130
/// A ternary operation
131
/// if true then "foo" else "bar"
132
Ternary {
133
predicate: Arc<Expr>,
134
truthy: Arc<Expr>,
135
falsy: Arc<Expr>,
136
},
137
Function {
138
/// function arguments
139
input: Vec<Expr>,
140
/// function to apply
141
function: FunctionExpr,
142
},
143
Explode {
144
input: Arc<Expr>,
145
options: ExplodeOptions,
146
},
147
Filter {
148
input: Arc<Expr>,
149
by: Arc<Expr>,
150
},
151
/// Polars flavored window functions.
152
Over {
153
/// Also has the input. i.e. avg("foo")
154
function: Arc<Expr>,
155
partition_by: Vec<Expr>,
156
order_by: Option<(Arc<Expr>, SortOptions)>,
157
mapping: WindowMapping,
158
},
159
#[cfg(feature = "dynamic_group_by")]
160
Rolling {
161
function: Arc<Expr>,
162
index_column: Arc<Expr>,
163
period: Duration,
164
offset: Duration,
165
closed_window: ClosedWindow,
166
},
167
Slice {
168
input: Arc<Expr>,
169
/// length is not yet known so we accept negative offsets
170
offset: Arc<Expr>,
171
length: Arc<Expr>,
172
},
173
/// Set root name as Alias
174
KeepName(Arc<Expr>),
175
Len,
176
#[cfg(feature = "dtype-struct")]
177
Field(Arc<[PlSmallStr]>),
178
AnonymousFunction {
179
/// function arguments
180
input: Vec<Expr>,
181
/// function to apply
182
function: OpaqueColumnUdf,
183
184
options: FunctionOptions,
185
/// used for formatting
186
#[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
187
fmt_str: Box<PlSmallStr>,
188
},
189
/// Evaluates the `evaluation` expression on the output of the `expr`.
190
///
191
/// Consequently, `expr` is an input and `evaluation` is not and needs a different schema.
192
Eval {
193
expr: Arc<Expr>,
194
evaluation: Arc<Expr>,
195
variant: EvalVariant,
196
},
197
/// Evaluates the `evaluation` expressions on the output of the `expr`.
198
///
199
/// Consequently, `expr` is an input and `evaluation` uses an extended schema that includes this input.
200
#[cfg(feature = "dtype-struct")]
201
StructEval {
202
expr: Arc<Expr>,
203
evaluation: Vec<Expr>,
204
},
205
/// SQL SubQueries
206
/// Plan,
207
/// Post-select expression and output-name of that expr
208
SubPlan(SpecialEq<Arc<DslPlan>>, Vec<(PlSmallStr, Expr)>),
209
RenameAlias {
210
function: RenameAliasFn,
211
expr: Arc<Expr>,
212
},
213
/// Not a real expression. This is meant
214
/// as catch-all for IR expressions that
215
/// are not supported by DSL.
216
Display {
217
inputs: Vec<Expr>,
218
fmt_str: Box<PlSmallStr>,
219
},
220
}
221
222
#[derive(Clone)]
223
pub enum LazySerde<T: Clone> {
224
Deserialized(T),
225
Bytes(Bytes),
226
/// Named functions allow for serializing arbitrary Rust functions as long as both sides know
227
/// ahead of time which function it is. There is a registry of functions that both sides know
228
/// and every time we need serialize we serialize the function by name in the registry.
229
///
230
/// Used by cloud.
231
Named {
232
// Name and payload are used by the NamedRegistry
233
// To load the function `T` at runtime.
234
name: String,
235
payload: Option<Bytes>,
236
// Sometimes we need the function `T` before sending
237
// to a different machine, so optionally set it as well.
238
value: Option<T>,
239
},
240
}
241
242
impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
243
fn eq(&self, other: &Self) -> bool {
244
use LazySerde as L;
245
match (self, other) {
246
(L::Deserialized(a), L::Deserialized(b)) => a == b,
247
(L::Bytes(a), L::Bytes(b)) => {
248
std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
249
},
250
(
251
L::Named {
252
name: l,
253
payload: pl,
254
value: _,
255
},
256
L::Named {
257
name: r,
258
payload: pr,
259
value: _,
260
},
261
) => {
262
#[cfg(debug_assertions)]
263
{
264
if l == r {
265
assert_eq!(pl, pr, "name should point to unique payload")
266
}
267
}
268
_ = pl;
269
_ = pr;
270
l == r
271
},
272
_ => false,
273
}
274
}
275
}
276
277
impl<T: Clone> Debug for LazySerde<T> {
278
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279
match self {
280
Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
281
Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
282
Self::Named {
283
name,
284
payload: _,
285
value: _,
286
} => write!(f, "lazy-serde<Named>: {name}"),
287
}
288
}
289
}
290
291
#[allow(clippy::derived_hash_with_manual_eq)]
292
impl Hash for Expr {
293
fn hash<H: Hasher>(&self, state: &mut H) {
294
let d = std::mem::discriminant(self);
295
d.hash(state);
296
match self {
297
Expr::Column(name) => name.hash(state),
298
// Expr::Columns(names) => names.hash(state),
299
// Expr::DtypeColumn(dtypes) => dtypes.hash(state),
300
// Expr::IndexColumn(indices) => indices.hash(state),
301
Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
302
Expr::Selector(s) => s.hash(state),
303
// Expr::Nth(v) => v.hash(state),
304
Expr::DataTypeFunction(v) => v.hash(state),
305
Expr::Filter { input, by } => {
306
input.hash(state);
307
by.hash(state);
308
},
309
Expr::BinaryExpr { left, op, right } => {
310
left.hash(state);
311
right.hash(state);
312
std::mem::discriminant(op).hash(state)
313
},
314
Expr::Cast {
315
expr,
316
dtype,
317
options: strict,
318
} => {
319
expr.hash(state);
320
dtype.hash(state);
321
strict.hash(state)
322
},
323
Expr::Sort { expr, options } => {
324
expr.hash(state);
325
options.hash(state);
326
},
327
Expr::Alias(input, name) => {
328
input.hash(state);
329
name.hash(state)
330
},
331
Expr::KeepName(input) => input.hash(state),
332
Expr::Ternary {
333
predicate,
334
truthy,
335
falsy,
336
} => {
337
predicate.hash(state);
338
truthy.hash(state);
339
falsy.hash(state);
340
},
341
Expr::Function { input, function } => {
342
input.hash(state);
343
std::mem::discriminant(function).hash(state);
344
},
345
Expr::Gather {
346
expr,
347
idx,
348
returns_scalar,
349
null_on_oob,
350
} => {
351
expr.hash(state);
352
idx.hash(state);
353
returns_scalar.hash(state);
354
null_on_oob.hash(state);
355
},
356
// already hashed by discriminant
357
Expr::Element | Expr::Len => {},
358
Expr::SortBy {
359
expr,
360
by,
361
sort_options,
362
} => {
363
expr.hash(state);
364
by.hash(state);
365
sort_options.hash(state);
366
},
367
Expr::Agg(input) => input.hash(state),
368
Expr::Explode { input, options } => {
369
options.hash(state);
370
input.hash(state)
371
},
372
#[cfg(feature = "dynamic_group_by")]
373
Expr::Rolling {
374
function,
375
index_column,
376
period,
377
offset,
378
closed_window,
379
} => {
380
function.hash(state);
381
index_column.hash(state);
382
period.hash(state);
383
offset.hash(state);
384
closed_window.hash(state);
385
},
386
Expr::Over {
387
function,
388
partition_by,
389
order_by,
390
mapping,
391
} => {
392
function.hash(state);
393
partition_by.hash(state);
394
order_by.hash(state);
395
mapping.hash(state);
396
},
397
Expr::Slice {
398
input,
399
offset,
400
length,
401
} => {
402
input.hash(state);
403
offset.hash(state);
404
length.hash(state);
405
},
406
Expr::RenameAlias { function, expr } => {
407
function.hash(state);
408
expr.hash(state);
409
},
410
Expr::Display { inputs, fmt_str } => {
411
inputs.hash(state);
412
fmt_str.hash(state);
413
},
414
Expr::AnonymousFunction {
415
input,
416
function: _,
417
options,
418
fmt_str,
419
} => {
420
input.hash(state);
421
options.hash(state);
422
fmt_str.hash(state);
423
},
424
Expr::Eval {
425
expr: input,
426
evaluation,
427
variant,
428
} => {
429
input.hash(state);
430
evaluation.hash(state);
431
variant.hash(state);
432
},
433
#[cfg(feature = "dtype-struct")]
434
Expr::StructEval {
435
expr: input,
436
evaluation,
437
} => {
438
input.hash(state);
439
evaluation.hash(state);
440
},
441
Expr::SubPlan(_, names) => names.hash(state),
442
#[cfg(feature = "dtype-struct")]
443
Expr::Field(names) => names.hash(state),
444
}
445
}
446
}
447
448
// Marker impl: opts `Expr` into `Eq` on top of its derived `PartialEq`.
impl Eq for Expr {}
449
450
impl Default for Expr {
451
fn default() -> Self {
452
Expr::Literal(LiteralValue::Scalar(Scalar::default()))
453
}
454
}
455
456
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
457
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
458
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
459
pub enum Excluded {
460
Name(PlSmallStr),
461
Dtype(DataType),
462
}
463
464
impl Expr {
465
/// Get Field result of the expression. The schema is the input data.
466
pub fn to_field(&self, schema: &Schema) -> PolarsResult<Field> {
467
// this is not called much and the expression depth is typically shallow
468
let mut arena = Arena::with_capacity(5);
469
self.to_field_amortized(schema, &mut arena)
470
}
471
pub(crate) fn to_field_amortized(
472
&self,
473
schema: &Schema,
474
expr_arena: &mut Arena<AExpr>,
475
) -> PolarsResult<Field> {
476
let mut ctx = ExprToIRContext::new_with_fields(expr_arena, schema);
477
ctx.allow_unknown = true;
478
let expr = to_expr_ir(self.clone(), &mut ctx)?;
479
let (node, output_name) = expr.into_inner();
480
let dtype = expr_arena
481
.get(node)
482
.to_dtype(&ToFieldContext::new(expr_arena, schema))?;
483
Ok(Field::new(output_name.into_inner().unwrap(), dtype))
484
}
485
486
pub fn into_selector(self) -> Option<Selector> {
487
match self {
488
Expr::Column(name) => Some(Selector::ByName {
489
names: [name].into(),
490
strict: true,
491
}),
492
Expr::Selector(selector) => Some(selector),
493
_ => None,
494
}
495
}
496
497
pub fn try_into_selector(self) -> PolarsResult<Selector> {
498
match self {
499
Expr::Column(name) => Ok(Selector::ByName {
500
names: [name].into(),
501
strict: true,
502
}),
503
Expr::Selector(selector) => Ok(selector),
504
expr => Err(polars_err!(InvalidOperation: "cannot turn `{expr}` into selector")),
505
}
506
}
507
508
/// Extract a constant usize from an expression.
509
pub fn extract_usize(&self) -> PolarsResult<usize> {
510
match self {
511
Expr::Literal(n) => n.extract_usize(),
512
Expr::Cast { expr, dtype, .. } => {
513
// lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
514
if dtype.as_literal().is_some_and(|dt| dt.is_integer()) {
515
expr.extract_usize()
516
} else {
517
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
518
}
519
},
520
_ => {
521
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
522
},
523
}
524
}
525
526
pub fn extract_i64(&self) -> PolarsResult<i64> {
527
match self {
528
Expr::Literal(n) => n.extract_i64(),
529
Expr::BinaryExpr { left, op, right } => match op {
530
Operator::Minus => {
531
let left = left.extract_i64()?;
532
let right = right.extract_i64()?;
533
Ok(left - right)
534
},
535
_ => unreachable!(),
536
},
537
Expr::Cast { expr, dtype, .. } => {
538
if dtype.as_literal().is_some_and(|dt| dt.is_integer()) {
539
expr.extract_i64()
540
} else {
541
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
542
}
543
},
544
_ => {
545
polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
546
},
547
}
548
}
549
550
#[inline]
551
pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
552
Expr::n_ary(function, vec![self])
553
}
554
#[inline]
555
pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
556
Expr::n_ary(function, vec![self, rhs])
557
}
558
559
#[inline]
560
pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
561
Expr::n_ary(function, vec![self, arg1, arg2])
562
}
563
564
#[inline]
565
pub fn try_map_n_ary(
566
self,
567
function: impl Into<FunctionExpr>,
568
exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
569
) -> PolarsResult<Expr> {
570
let exprs = exprs.into_iter();
571
let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
572
input.push(self);
573
for e in exprs {
574
input.push(e?);
575
}
576
Ok(Expr::n_ary(function, input))
577
}
578
579
#[inline]
580
pub fn map_n_ary(
581
self,
582
function: impl Into<FunctionExpr>,
583
exprs: impl IntoIterator<Item = Expr>,
584
) -> Expr {
585
let exprs = exprs.into_iter();
586
let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
587
input.push(self);
588
input.extend(exprs);
589
Expr::n_ary(function, input)
590
}
591
592
#[inline]
593
pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
594
let function = function.into();
595
Expr::Function { input, function }
596
}
597
}
598
599
/// Selects which flavour of evaluation an `Expr::Eval` performs.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum EvalVariant {
    /// `list.eval`
    List,
    /// `list.agg`
    ListAgg,

    /// `array.eval`
    Array {
        /// If set to true, evaluation can output variable amount of items and output datatype will
        /// be `List`.
        as_list: bool,
    },
    /// `arr.agg`
    ArrayAgg,

    /// `cumulative_eval`
    Cumulative { min_samples: usize },
}
620
621
impl EvalVariant {
622
pub fn to_name(&self) -> &'static str {
623
match self {
624
Self::List => "list.eval",
625
Self::ListAgg => "list.agg",
626
Self::Array { .. } => "array.eval",
627
Self::ArrayAgg => "array.agg",
628
Self::Cumulative { min_samples: _ } => "cumulative_eval",
629
}
630
}
631
632
/// Get the `DataType` of the `pl.element()` value.
633
pub fn element_dtype<'a>(&self, dtype: &'a DataType) -> PolarsResult<&'a DataType> {
634
match (self, dtype) {
635
(Self::List | Self::ListAgg, DataType::List(inner)) => Ok(inner.as_ref()),
636
#[cfg(feature = "dtype-array")]
637
(Self::Array { .. } | Self::ArrayAgg, DataType::Array(inner, _)) => Ok(inner.as_ref()),
638
(Self::Cumulative { min_samples: _ }, dt) => Ok(dt),
639
_ => polars_bail!(op = self.to_name(), dtype),
640
}
641
}
642
643
/// Get the output datatype from the output element datatype
644
pub fn output_dtype(
645
&self,
646
dtype: &'_ DataType,
647
output_element_dtype: DataType,
648
eval_is_scalar: bool,
649
) -> PolarsResult<DataType> {
650
match (self, dtype) {
651
(Self::List, DataType::List(_)) => Ok(DataType::List(Box::new(output_element_dtype))),
652
(Self::ListAgg, DataType::List(_)) => {
653
if eval_is_scalar {
654
Ok(output_element_dtype)
655
} else {
656
Ok(DataType::List(Box::new(output_element_dtype)))
657
}
658
},
659
#[cfg(feature = "dtype-array")]
660
(Self::Array { as_list: false }, DataType::Array(_, width)) => {
661
Ok(DataType::Array(Box::new(output_element_dtype), *width))
662
},
663
#[cfg(feature = "dtype-array")]
664
(Self::Array { as_list: true }, DataType::Array(_, _)) => {
665
Ok(DataType::List(Box::new(output_element_dtype)))
666
},
667
#[cfg(feature = "dtype-array")]
668
(Self::ArrayAgg, DataType::Array(_, _)) => {
669
if eval_is_scalar {
670
Ok(output_element_dtype)
671
} else {
672
Ok(DataType::List(Box::new(output_element_dtype)))
673
}
674
},
675
(Self::Cumulative { min_samples: _ }, _) => Ok(output_element_dtype),
676
_ => polars_bail!(op = self.to_name(), dtype),
677
}
678
}
679
680
pub fn is_elementwise(&self) -> bool {
681
match self {
682
EvalVariant::List | EvalVariant::ListAgg => true,
683
EvalVariant::Array { .. } | EvalVariant::ArrayAgg => true,
684
EvalVariant::Cumulative { min_samples: _ } => false,
685
}
686
}
687
688
pub fn is_row_separable(&self) -> bool {
689
match self {
690
EvalVariant::List | EvalVariant::ListAgg => true,
691
EvalVariant::Array { .. } | EvalVariant::ArrayAgg => true,
692
EvalVariant::Cumulative { min_samples: _ } => false,
693
}
694
}
695
696
pub fn is_length_preserving(&self) -> bool {
697
match self {
698
EvalVariant::List
699
| EvalVariant::ListAgg
700
| EvalVariant::Array { .. }
701
| EvalVariant::ArrayAgg
702
| EvalVariant::Cumulative { .. } => true,
703
}
704
}
705
}
706
707
/// Binary operators usable in `Expr::BinaryExpr`.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum Operator {
    // Comparisons
    Eq,
    EqValidity,
    NotEq,
    NotEqValidity,
    Lt,
    LtEq,
    Gt,
    GtEq,

    // Arithmetic
    Plus,
    Minus,
    Multiply,
    /// Rust division semantics, this is what Rust interface `/` dispatches to
    RustDivide,
    /// Python division semantics, converting to floats. This is what python `/` operator dispatches to
    TrueDivide,
    /// Floor division semantics, this is what python `//` dispatches to
    FloorDivide,
    Modulus,

    // Bitwise and logical
    And,
    Or,
    Xor,
    LogicalAnd,
    LogicalOr,
}
735
736
impl Display for Operator {
737
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
738
use Operator::*;
739
let tkn = match self {
740
Eq => "==",
741
EqValidity => "==v",
742
NotEq => "!=",
743
NotEqValidity => "!=v",
744
Lt => "<",
745
LtEq => "<=",
746
Gt => ">",
747
GtEq => ">=",
748
Plus => "+",
749
Minus => "-",
750
Multiply => "*",
751
RustDivide => "rust_div",
752
TrueDivide => "/",
753
FloorDivide => "//",
754
Modulus => "%",
755
And | LogicalAnd => "&",
756
Or | LogicalOr => "|",
757
Xor => "^",
758
};
759
write!(f, "{tkn}")
760
}
761
}
762
763
impl Operator {
764
pub fn is_comparison(&self) -> bool {
765
matches!(
766
self,
767
Self::Eq
768
| Self::NotEq
769
| Self::Lt
770
| Self::LtEq
771
| Self::Gt
772
| Self::GtEq
773
| Self::EqValidity
774
| Self::NotEqValidity
775
)
776
}
777
778
pub fn is_bitwise(&self) -> bool {
779
matches!(self, Self::And | Self::Or | Self::Xor)
780
}
781
782
pub fn is_comparison_or_bitwise(&self) -> bool {
783
self.is_comparison() || self.is_bitwise()
784
}
785
786
pub fn swap_operands(self) -> Self {
787
match self {
788
Operator::Eq => Operator::Eq,
789
Operator::Gt => Operator::Lt,
790
Operator::GtEq => Operator::LtEq,
791
Operator::LtEq => Operator::GtEq,
792
Operator::Or => Operator::Or,
793
Operator::LogicalAnd => Operator::LogicalAnd,
794
Operator::LogicalOr => Operator::LogicalOr,
795
Operator::Xor => Operator::Xor,
796
Operator::NotEq => Operator::NotEq,
797
Operator::EqValidity => Operator::EqValidity,
798
Operator::NotEqValidity => Operator::NotEqValidity,
799
// Operator::Divide requires modifying the right operand: left / right == 1/right * left
800
Operator::RustDivide => unimplemented!(),
801
Operator::Multiply => Operator::Multiply,
802
Operator::And => Operator::And,
803
Operator::Plus => Operator::Plus,
804
// Operator::Minus requires modifying the right operand: left - right == -right + left
805
Operator::Minus => unimplemented!(),
806
Operator::Lt => Operator::Gt,
807
_ => unimplemented!(),
808
}
809
}
810
811
pub fn is_arithmetic(&self) -> bool {
812
!(self.is_comparison_or_bitwise())
813
}
814
}
815
816
#[derive(Clone, PartialEq, Eq, Hash)]
817
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
818
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
819
pub enum RenameAliasFn {
820
Prefix(PlSmallStr),
821
Suffix(PlSmallStr),
822
ToLowercase,
823
ToUppercase,
824
Map(PlanCallback<PlSmallStr, PlSmallStr>),
825
Replace {
826
pattern: PlSmallStr,
827
value: PlSmallStr,
828
literal: bool,
829
},
830
}
831
832
impl RenameAliasFn {
833
pub fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
834
let out = match self {
835
Self::Prefix(prefix) => format_pl_smallstr!("{prefix}{name}"),
836
Self::Suffix(suffix) => format_pl_smallstr!("{name}{suffix}"),
837
Self::ToLowercase => PlSmallStr::from_string(name.to_lowercase()),
838
Self::ToUppercase => PlSmallStr::from_string(name.to_uppercase()),
839
Self::Map(f) => f.call(name.clone())?,
840
Self::Replace {
841
pattern,
842
value,
843
literal,
844
} => {
845
if *literal {
846
name.replace(pattern.as_str(), value.as_str()).into()
847
} else {
848
feature_gated!("regex", {
849
let rx = polars_utils::regex_cache::compile_regex(pattern)?;
850
rx.replace_all(name, value.as_str()).into()
851
})
852
}
853
},
854
};
855
Ok(out)
856
}
857
}
858
859
pub type RenameAliasRustFn =
860
dyn Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + 'static + Send + Sync;
861
862