Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/src/functions.rs
8364 views
1
use std::ops::{Add, Sub};
2
3
use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
4
use polars_core::prelude::{
5
DataType, ExplodeOptions, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail,
6
polars_err,
7
};
8
use polars_lazy::dsl::Expr;
9
#[cfg(feature = "rank")]
10
use polars_lazy::prelude::{RankMethod, RankOptions};
11
use polars_ops::chunked_array::UnicodeForm;
12
use polars_ops::series::RoundMode;
13
use polars_plan::dsl::functions::{
14
as_struct, coalesce, col, cols, concat_str, element, int_range, len, lit, max_horizontal,
15
min_horizontal, when,
16
};
17
use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
18
use polars_plan::prelude::StrptimeOptions;
19
use polars_utils::pl_str::PlSmallStr;
20
use sqlparser::ast::helpers::attached_token::AttachedToken;
21
use sqlparser::ast::{
22
DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
23
FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
24
OrderByExpr, Value as SQLValue, ValueWithSpan, WindowFrame, WindowFrameBound, WindowFrameUnits,
25
WindowSpec, WindowType,
26
};
27
use sqlparser::tokenizer::Span;
28
29
use crate::SQLContext;
30
use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
31
32
pub(crate) struct SQLFunctionVisitor<'a> {
33
pub(crate) func: &'a SQLFunction,
34
pub(crate) ctx: &'a mut SQLContext,
35
pub(crate) active_schema: Option<&'a Schema>,
36
}
37
38
/// SQL functions that are supported by Polars.
///
/// Each variant corresponds to one (or more, via aliases) SQL function name; user-defined
/// functions registered on the SQL context are represented by [`PolarsSQLFunctions::Udf`].
pub(crate) enum PolarsSQLFunctions {
    // ----
    // Bitwise functions
    // ----
    /// SQL 'bit_and' function.
    /// Returns the bitwise AND of the input expressions.
    /// ```sql
    /// SELECT BIT_AND(col1, col2) FROM df;
    /// ```
    BitAnd,
    /// SQL 'bit_count' function.
    /// Returns the number of set bits in the input expression.
    /// ```sql
    /// SELECT BIT_COUNT(col1) FROM df;
    /// ```
    #[cfg(feature = "bitwise")]
    BitCount,
    /// SQL 'bit_not' function.
    /// Returns the bitwise NOT of the input expression.
    /// ```sql
    /// SELECT BIT_NOT(col1) FROM df;
    /// ```
    BitNot,
    /// SQL 'bit_or' function.
    /// Returns the bitwise OR of the input expressions.
    /// ```sql
    /// SELECT BIT_OR(col1, col2) FROM df;
    /// ```
    BitOr,
    /// SQL 'bit_xor' function.
    /// Returns the bitwise XOR of the input expressions.
    /// ```sql
    /// SELECT BIT_XOR(col1, col2) FROM df;
    /// ```
    BitXor,

    // ----
    // Math functions
    // ----
    /// SQL 'abs' function.
    /// Returns the absolute value of the input expression.
    /// ```sql
    /// SELECT ABS(col1) FROM df;
    /// ```
    Abs,
    /// SQL 'ceil' function.
    /// Returns the smallest integer greater than or equal to the input
    /// (rounds towards positive infinity).
    /// ```sql
    /// SELECT CEIL(col1) FROM df;
    /// ```
    Ceil,
    /// SQL 'div' function.
    /// Returns the integer quotient of the division (floor division).
    /// ```sql
    /// SELECT DIV(col1, 2) FROM df;
    /// ```
    Div,
    /// SQL 'exp' function.
    /// Computes the exponential of the given value.
    /// ```sql
    /// SELECT EXP(col1) FROM df;
    /// ```
    Exp,
    /// SQL 'floor' function.
    /// Returns the largest integer less than or equal to the input
    /// (rounds towards negative infinity).
    /// ```sql
    /// SELECT FLOOR(col1) FROM df;
    /// ```
    Floor,
    /// SQL 'pi' function.
    /// Returns a (very good) approximation of 𝜋.
    /// ```sql
    /// SELECT PI() FROM df;
    /// ```
    Pi,
    /// SQL 'ln' function.
    /// Computes the natural logarithm of the given value.
    /// ```sql
    /// SELECT LN(col1) FROM df;
    /// ```
    Ln,
    /// SQL 'log2' function.
    /// Computes the logarithm of the given value in base 2.
    /// ```sql
    /// SELECT LOG2(col1) FROM df;
    /// ```
    Log2,
    /// SQL 'log10' function.
    /// Computes the logarithm of the given value in base 10.
    /// ```sql
    /// SELECT LOG10(col1) FROM df;
    /// ```
    Log10,
    /// SQL 'log' function.
    /// Computes the `base` logarithm of the given value.
    /// ```sql
    /// SELECT LOG(col1, 10) FROM df;
    /// ```
    Log,
    /// SQL 'log1p' function.
    /// Computes the natural logarithm of "given value plus one".
    /// ```sql
    /// SELECT LOG1P(col1) FROM df;
    /// ```
    Log1p,
    /// SQL 'pow' function.
    /// Returns the value to the power of the given exponent.
    /// ```sql
    /// SELECT POW(col1, 2) FROM df;
    /// ```
    Pow,
    /// SQL 'mod' function.
    /// Returns the remainder of a numeric expression divided by another numeric expression.
    /// ```sql
    /// SELECT MOD(col1, 2) FROM df;
    /// ```
    Mod,
    /// SQL 'sqrt' function.
    /// Returns the square root (√) of a number.
    /// ```sql
    /// SELECT SQRT(col1) FROM df;
    /// ```
    Sqrt,
    /// SQL 'cbrt' function.
    /// Returns the cube root (∛) of a number.
    /// ```sql
    /// SELECT CBRT(col1) FROM df;
    /// ```
    Cbrt,
    /// SQL 'round' function.
    /// Round a number to `x` decimals (default: 0); .5 is rounded away from zero.
    /// Negative `decimals` values are not currently supported.
    /// ```sql
    /// SELECT ROUND(col1, 3) FROM df;
    /// ```
    Round,
    /// SQL 'sign' function.
    /// Returns the sign of the argument as -1, 0, or +1.
    /// ```sql
    /// SELECT SIGN(col1) FROM df;
    /// ```
    Sign,

    // ----
    // Trig functions
    // ----
    /// SQL 'cos' function.
    /// Compute the cosine of the input expression (in radians).
    /// ```sql
    /// SELECT COS(col1) FROM df;
    /// ```
    Cos,
    /// SQL 'cot' function.
    /// Compute the cotangent of the input expression (in radians).
    /// ```sql
    /// SELECT COT(col1) FROM df;
    /// ```
    Cot,
    /// SQL 'sin' function.
    /// Compute the sine of the input expression (in radians).
    /// ```sql
    /// SELECT SIN(col1) FROM df;
    /// ```
    Sin,
    /// SQL 'tan' function.
    /// Compute the tangent of the input expression (in radians).
    /// ```sql
    /// SELECT TAN(col1) FROM df;
    /// ```
    Tan,
    /// SQL 'cosd' function.
    /// Compute the cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT COSD(col1) FROM df;
    /// ```
    CosD,
    /// SQL 'cotd' function.
    /// Compute the cotangent of the input expression (in degrees).
    /// ```sql
    /// SELECT COTD(col1) FROM df;
    /// ```
    CotD,
    /// SQL 'sind' function.
    /// Compute the sine of the input expression (in degrees).
    /// ```sql
    /// SELECT SIND(col1) FROM df;
    /// ```
    SinD,
    /// SQL 'tand' function.
    /// Compute the tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT TAND(col1) FROM df;
    /// ```
    TanD,
    /// SQL 'acos' function.
    /// Compute the inverse cosine of the input expression (in radians).
    /// ```sql
    /// SELECT ACOS(col1) FROM df;
    /// ```
    Acos,
    /// SQL 'asin' function.
    /// Compute the inverse sine of the input expression (in radians).
    /// ```sql
    /// SELECT ASIN(col1) FROM df;
    /// ```
    Asin,
    /// SQL 'atan' function.
    /// Compute the inverse tangent of the input expression (in radians).
    /// ```sql
    /// SELECT ATAN(col1) FROM df;
    /// ```
    Atan,
    /// SQL 'atan2' function.
    /// Compute the inverse tangent of col1/col2 (in radians).
    /// ```sql
    /// SELECT ATAN2(col1, col2) FROM df;
    /// ```
    Atan2,
    /// SQL 'acosd' function.
    /// Compute the inverse cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT ACOSD(col1) FROM df;
    /// ```
    AcosD,
    /// SQL 'asind' function.
    /// Compute the inverse sine of the input expression (in degrees).
    /// ```sql
    /// SELECT ASIND(col1) FROM df;
    /// ```
    AsinD,
    /// SQL 'atand' function.
    /// Compute the inverse tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT ATAND(col1) FROM df;
    /// ```
    AtanD,
    /// SQL 'atan2d' function.
    /// Compute the inverse tangent of col1/col2 (in degrees).
    /// ```sql
    /// SELECT ATAN2D(col1, col2) FROM df;
    /// ```
    Atan2D,
    /// SQL 'degrees' function.
    /// Convert the input from radians to degrees.
    /// ```sql
    /// SELECT DEGREES(col1) FROM df;
    /// ```
    Degrees,
    /// SQL 'radians' function.
    /// Convert the input from degrees to radians.
    /// ```sql
    /// SELECT RADIANS(col1) FROM df;
    /// ```
    Radians,

    // ----
    // Temporal functions
    // ----
    /// SQL 'date_part' function.
    /// Extracts a part of a date (or datetime) such as 'year', 'month', etc.
    /// ```sql
    /// SELECT DATE_PART('year', col1) FROM df;
    /// SELECT DATE_PART('day', col1) FROM df;
    /// ```
    DatePart,
    /// SQL 'strftime' function.
    /// Converts a datetime to a string using a format string.
    /// ```sql
    /// SELECT STRFTIME(col1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strftime,

    // ----
    // String functions
    // ----
    /// SQL 'bit_length' function.
    /// Returns the length of the given string in bits.
    /// ```sql
    /// SELECT BIT_LENGTH(col1) FROM df;
    /// ```
    BitLength,
    /// SQL 'concat' function.
    /// Returns all input expressions concatenated together as a string.
    /// ```sql
    /// SELECT CONCAT(col1, col2) FROM df;
    /// ```
    Concat,
    /// SQL 'concat_ws' function.
    /// Returns all input expressions concatenated together
    /// (and interleaved with a separator) as a string.
    /// ```sql
    /// SELECT CONCAT_WS(':', col1, col2, col3) FROM df;
    /// ```
    ConcatWS,
    /// SQL 'date' function.
    /// Converts a formatted string date to an actual Date type; ISO-8601 format is assumed
    /// unless a strftime-compatible formatting string is provided as the second parameter.
    /// ```sql
    /// SELECT DATE('2021-03-15') FROM df;
    /// SELECT DATE('2021-15-03', '%Y-%d-%m') FROM df;
    /// SELECT DATE('2021-03', '%Y-%m') FROM df;
    /// ```
    Date,
    /// SQL 'ends_with' function.
    /// Returns True if the value ends with the second argument.
    /// ```sql
    /// SELECT ENDS_WITH(col1, 'a') FROM df;
    /// SELECT col2 from df WHERE ENDS_WITH(col1, 'a');
    /// ```
    EndsWith,
    /// SQL 'initcap' function.
    /// Returns the value with the first letter capitalized.
    /// ```sql
    /// SELECT INITCAP(col1) FROM df;
    /// ```
    #[cfg(feature = "nightly")]
    InitCap,
    /// SQL 'left' function.
    /// Returns the first (leftmost) `n` characters.
    /// ```sql
    /// SELECT LEFT(col1, 3) FROM df;
    /// ```
    Left,
    /// SQL 'length' function (characters).
    /// Returns the character length of the string.
    /// ```sql
    /// SELECT LENGTH(col1) FROM df;
    /// ```
    Length,
    /// SQL 'lower' function.
    /// Returns a lowercased column.
    /// ```sql
    /// SELECT LOWER(col1) FROM df;
    /// ```
    Lower,
    /// SQL 'ltrim' function.
    /// Strip whitespaces from the left.
    /// ```sql
    /// SELECT LTRIM(col1) FROM df;
    /// ```
    LTrim,
    /// SQL 'normalize' function.
    /// Convert string to Unicode normalization form
    /// (one of NFC, NFKC, NFD, or NFKD - unquoted).
    /// ```sql
    /// SELECT NORMALIZE(col1, NFC) FROM df;
    /// ```
    Normalize,
    /// SQL 'octet_length' function.
    /// Returns the length of a given string in bytes.
    /// ```sql
    /// SELECT OCTET_LENGTH(col1) FROM df;
    /// ```
    OctetLength,
    /// SQL 'regexp_like' function.
    /// True if `pattern` matches the value (optional: `flags`).
    /// ```sql
    /// SELECT REGEXP_LIKE(col1, 'xyz', 'i') FROM df;
    /// ```
    RegexpLike,
    /// SQL 'replace' function.
    /// Replace a given substring with another string.
    /// ```sql
    /// SELECT REPLACE(col1, 'old', 'new') FROM df;
    /// ```
    Replace,
    /// SQL 'reverse' function.
    /// Return the reversed string.
    /// ```sql
    /// SELECT REVERSE(col1) FROM df;
    /// ```
    Reverse,
    /// SQL 'right' function.
    /// Returns the last (rightmost) `n` characters.
    /// ```sql
    /// SELECT RIGHT(col1, 3) FROM df;
    /// ```
    Right,
    /// SQL 'rtrim' function.
    /// Strip whitespaces from the right.
    /// ```sql
    /// SELECT RTRIM(col1) FROM df;
    /// ```
    RTrim,
    /// SQL 'split_part' function.
    /// Splits a string into an array of strings using the given delimiter
    /// and returns the `n`-th part (1-indexed).
    /// ```sql
    /// SELECT SPLIT_PART(col1, ',', 2) FROM df;
    /// ```
    SplitPart,
    /// SQL 'starts_with' function.
    /// Returns True if the value starts with the second argument.
    /// ```sql
    /// SELECT STARTS_WITH(col1, 'a') FROM df;
    /// SELECT col2 from df WHERE STARTS_WITH(col1, 'a');
    /// ```
    StartsWith,
    /// SQL 'strpos' function.
    /// Returns the index of the given substring in the target string.
    /// ```sql
    /// SELECT STRPOS(col1,'xyz') FROM df;
    /// ```
    StrPos,
    /// SQL 'substr' function.
    /// Returns a portion of the data (first character = 1) in the range.
    /// \[start, start + length]
    /// ```sql
    /// SELECT SUBSTR(col1, 3, 5) FROM df;
    /// ```
    Substring,
    /// SQL 'string_to_array' function.
    /// Splits a string into an array of strings using the given delimiter.
    /// ```sql
    /// SELECT STRING_TO_ARRAY(col1, ',') FROM df;
    /// ```
    StringToArray,
    /// SQL 'strptime' function.
    /// Converts a string to a datetime using a format string.
    /// ```sql
    /// SELECT STRPTIME(col1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strptime,
    /// SQL 'time' function.
    /// Converts a formatted string time to an actual Time type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIME('10:30:45') FROM df;
    /// SELECT TIME('20.30', '%H.%M') FROM df;
    /// ```
    Time,
    /// SQL 'timestamp' function.
    /// Converts a formatted string datetime to an actual Datetime type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIMESTAMP('2021-03-15 10:30:45') FROM df;
    /// SELECT TIMESTAMP('2021-15-03T00:01:02.333', '%Y-%d-%mT%H:%M:%S%.f') FROM df;
    /// ```
    Timestamp,
    /// SQL 'upper' function.
    /// Returns an uppercased column.
    /// ```sql
    /// SELECT UPPER(col1) FROM df;
    /// ```
    Upper,

    // ----
    // Conditional functions
    // ----
    /// SQL 'coalesce' function.
    /// Returns the first non-null value in the provided values/columns.
    /// ```sql
    /// SELECT COALESCE(col1, ...) FROM df;
    /// ```
    Coalesce,
    /// SQL 'greatest' function.
    /// Returns the greatest value in the list of expressions.
    /// ```sql
    /// SELECT GREATEST(col1, col2, ...) FROM df;
    /// ```
    Greatest,
    /// SQL 'if' function.
    /// Returns expr1 if the boolean condition provided as the first
    /// parameter evaluates to true, and expr2 otherwise.
    /// ```sql
    /// SELECT IF(column < 0, expr1, expr2) FROM df;
    /// ```
    If,
    /// SQL 'ifnull' function.
    /// If an expression value is NULL, return an alternative value.
    /// ```sql
    /// SELECT IFNULL(string_col, 'n/a') FROM df;
    /// ```
    IfNull,
    /// SQL 'least' function.
    /// Returns the smallest value in the list of expressions.
    /// ```sql
    /// SELECT LEAST(col1, col2, ...) FROM df;
    /// ```
    Least,
    /// SQL 'nullif' function.
    /// Returns NULL if two expressions are equal, otherwise returns the first.
    /// ```sql
    /// SELECT NULLIF(col1, col2) FROM df;
    /// ```
    NullIf,

    // ----
    // Aggregate functions
    // ----
    /// SQL 'avg' function.
    /// Returns the average (mean) of all the elements in the grouping.
    /// ```sql
    /// SELECT AVG(col1) FROM df;
    /// ```
    Avg,
    /// SQL 'corr' function.
    /// Returns the Pearson correlation coefficient between two columns.
    /// ```sql
    /// SELECT CORR(col1, col2) FROM df;
    /// ```
    Corr,
    /// SQL 'count' function.
    /// Returns the amount of elements in the grouping.
    /// ```sql
    /// SELECT COUNT(col1) FROM df;
    /// SELECT COUNT(*) FROM df;
    /// SELECT COUNT(DISTINCT col1) FROM df;
    /// SELECT COUNT(DISTINCT *) FROM df;
    /// ```
    Count,
    /// SQL 'covar_pop' function.
    /// Returns the population covariance between two columns.
    /// ```sql
    /// SELECT COVAR_POP(col1, col2) FROM df;
    /// ```
    CovarPop,
    /// SQL 'covar_samp' function.
    /// Returns the sample covariance between two columns.
    /// ```sql
    /// SELECT COVAR_SAMP(col1, col2) FROM df;
    /// ```
    CovarSamp,
    /// SQL 'first' function.
    /// Returns the first element of the grouping.
    /// ```sql
    /// SELECT FIRST(col1) FROM df;
    /// ```
    First,
    /// SQL 'last' function.
    /// Returns the last element of the grouping.
    /// ```sql
    /// SELECT LAST(col1) FROM df;
    /// ```
    Last,
    /// SQL 'max' function.
    /// Returns the greatest (maximum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MAX(col1) FROM df;
    /// ```
    Max,
    /// SQL 'median' function.
    /// Returns the median element from the grouping.
    /// ```sql
    /// SELECT MEDIAN(col1) FROM df;
    /// ```
    Median,
    /// SQL 'quantile_cont' function.
    /// Returns the continuous quantile element from the grouping
    /// (interpolated value between two closest values).
    /// ```sql
    /// SELECT QUANTILE_CONT(col1) FROM df;
    /// ```
    QuantileCont,
    /// SQL 'quantile_disc' function.
    /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value,
    /// and returns the value associated with the subinterval where the quantile value falls.
    /// ```sql
    /// SELECT QUANTILE_DISC(col1) FROM df;
    /// ```
    QuantileDisc,
    /// SQL 'min' function.
    /// Returns the smallest (minimum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MIN(col1) FROM df;
    /// ```
    Min,
    /// SQL 'stddev' function.
    /// Returns the standard deviation of all the elements in the grouping.
    /// ```sql
    /// SELECT STDDEV(col1) FROM df;
    /// ```
    StdDev,
    /// SQL 'sum' function.
    /// Returns the sum of all the elements in the grouping.
    /// ```sql
    /// SELECT SUM(col1) FROM df;
    /// ```
    Sum,
    /// SQL 'variance' function.
    /// Returns the variance of all the elements in the grouping.
    /// ```sql
    /// SELECT VARIANCE(col1) FROM df;
    /// ```
    Variance,

    // ----
    // Array functions
    // ----
    /// SQL 'array_length' function.
    /// Returns the length of the array.
    /// ```sql
    /// SELECT ARRAY_LENGTH(col1) FROM df;
    /// ```
    ArrayLength,
    /// SQL 'array_lower' function.
    /// Returns the minimum value in an array; equivalent to `array_min`.
    /// ```sql
    /// SELECT ARRAY_LOWER(col1) FROM df;
    /// ```
    ArrayMin,
    /// SQL 'array_upper' function.
    /// Returns the maximum value in an array; equivalent to `array_max`.
    /// ```sql
    /// SELECT ARRAY_UPPER(col1) FROM df;
    /// ```
    ArrayMax,
    /// SQL 'array_sum' function.
    /// Returns the sum of all values in an array.
    /// ```sql
    /// SELECT ARRAY_SUM(col1) FROM df;
    /// ```
    ArraySum,
    /// SQL 'array_mean' function.
    /// Returns the mean of all values in an array.
    /// ```sql
    /// SELECT ARRAY_MEAN(col1) FROM df;
    /// ```
    ArrayMean,
    /// SQL 'array_reverse' function.
    /// Returns the array with the elements in reverse order.
    /// ```sql
    /// SELECT ARRAY_REVERSE(col1) FROM df;
    /// ```
    ArrayReverse,
    /// SQL 'array_unique' function.
    /// Returns the array with the unique elements.
    /// ```sql
    /// SELECT ARRAY_UNIQUE(col1) FROM df;
    /// ```
    ArrayUnique,
    /// SQL 'unnest' function.
    /// Unnests (explodes) an array column into multiple rows.
    /// ```sql
    /// SELECT unnest(col1) FROM df;
    /// ```
    Explode,
    /// SQL 'array_agg' function.
    /// Concatenates the input expressions, including nulls, into an array.
    /// ```sql
    /// SELECT ARRAY_AGG(col1, col2, ...) FROM df;
    /// ```
    ArrayAgg,
    /// SQL 'array_to_string' function.
    /// Takes all elements of the array and joins them into one string.
    /// ```sql
    /// SELECT ARRAY_TO_STRING(col1, ',') FROM df;
    /// SELECT ARRAY_TO_STRING(col1, ',', 'n/a') FROM df;
    /// ```
    ArrayToString,
    /// SQL 'array_get' function.
    /// Returns the value at the given index in the array.
    /// ```sql
    /// SELECT ARRAY_GET(col1, 1) FROM df;
    /// ```
    ArrayGet,
    /// SQL 'array_contains' function.
    /// Returns true if the array contains the value.
    /// ```sql
    /// SELECT ARRAY_CONTAINS(col1, 'foo') FROM df;
    /// ```
    ArrayContains,

    // ----
    // Window functions
    // ----
    /// SQL 'first_value' window function.
    /// Returns the first value in an ordered set of values (respecting window frame).
    /// ```sql
    /// SELECT FIRST_VALUE(col1) OVER (PARTITION BY category ORDER BY id) FROM df;
    /// ```
    FirstValue,
    /// SQL 'last_value' window function.
    /// Returns the last value in an ordered set of values (respecting window frame).
    /// With default frame, returns the current row's value.
    /// ```sql
    /// SELECT LAST_VALUE(col1) OVER (PARTITION BY category ORDER BY id) FROM df;
    /// ```
    LastValue,
    /// SQL 'lag' function.
    /// Returns the value of the expression evaluated at the row n rows before the current row.
    /// ```sql
    /// SELECT lag(column_1, 1) OVER (PARTITION BY column_2 ORDER BY column_3) FROM df;
    /// ```
    Lag,
    /// SQL 'lead' function.
    /// Returns the value of the expression evaluated at the row n rows after the current row.
    /// ```sql
    /// SELECT lead(column_1, 1) OVER (PARTITION BY column_2 ORDER BY column_3) FROM df;
    /// ```
    Lead,
    /// SQL 'row_number' function.
    /// Returns the sequential row number within a window partition, starting from 1.
    /// ```sql
    /// SELECT ROW_NUMBER() OVER (ORDER BY col1) FROM df;
    /// SELECT ROW_NUMBER() OVER (PARTITION BY col1 ORDER BY col2) FROM df;
    /// ```
    RowNumber,
    /// SQL 'rank' function.
    /// Returns the rank of each row within a window partition, with gaps for ties.
    /// Rows with equal values receive the same rank, and the next rank skips numbers.
    /// ```sql
    /// SELECT RANK() OVER (ORDER BY col1) FROM df;
    /// SELECT RANK() OVER (PARTITION BY col1 ORDER BY col2 DESC) FROM df;
    /// ```
    #[cfg(feature = "rank")]
    Rank,
    /// SQL 'dense_rank' function.
    /// Returns the rank of each row within a window partition, without gaps for ties.
    /// Rows with equal values receive the same rank, and the next rank is consecutive.
    /// ```sql
    /// SELECT DENSE_RANK() OVER (ORDER BY col1) FROM df;
    /// SELECT DENSE_RANK() OVER (PARTITION BY col1 ORDER BY col2 DESC) FROM df;
    /// ```
    #[cfg(feature = "rank")]
    DenseRank,

    // ----
    // Column selection
    // ----
    /// SQL 'columns' function (column selection).
    Columns,

    // ----
    // User-defined
    // ----
    /// A user-defined function, identified by its (lowercased) registered name.
    Udf(String),
}
769
770
impl PolarsSQLFunctions {
771
pub(crate) fn keywords() -> &'static [&'static str] {
772
&[
773
"abs",
774
"acos",
775
"acosd",
776
"array_contains",
777
"array_get",
778
"array_length",
779
"array_lower",
780
"array_mean",
781
"array_reverse",
782
"array_sum",
783
"array_to_string",
784
"array_unique",
785
"array_upper",
786
"asin",
787
"asind",
788
"atan",
789
"atan2",
790
"atan2d",
791
"atand",
792
"avg",
793
"bit_and",
794
"bit_count",
795
"bit_length",
796
"bit_or",
797
"bit_xor",
798
"cbrt",
799
"ceil",
800
"ceiling",
801
"char_length",
802
"character_length",
803
"coalesce",
804
"columns",
805
"concat",
806
"concat_ws",
807
"corr",
808
"cos",
809
"cosd",
810
"cot",
811
"cotd",
812
"count",
813
"covar",
814
"covar_pop",
815
"covar_samp",
816
"date",
817
"date_part",
818
"degrees",
819
"dense_rank",
820
"ends_with",
821
"exp",
822
"first",
823
"first_value",
824
"floor",
825
"greatest",
826
"if",
827
"ifnull",
828
"initcap",
829
"lag",
830
"last",
831
"last_value",
832
"lead",
833
"least",
834
"left",
835
"length",
836
"ln",
837
"log",
838
"log10",
839
"log1p",
840
"log2",
841
"lower",
842
"ltrim",
843
"max",
844
"median",
845
"quantile_disc",
846
"min",
847
"mod",
848
"nullif",
849
"octet_length",
850
"pi",
851
"pow",
852
"power",
853
"quantile_cont",
854
"quantile_disc",
855
"radians",
856
"rank",
857
"regexp_like",
858
"replace",
859
"reverse",
860
"right",
861
"round",
862
"row_number",
863
"rtrim",
864
"sign",
865
"sin",
866
"sind",
867
"sqrt",
868
"starts_with",
869
"stddev",
870
"stddev_samp",
871
"stdev",
872
"stdev_samp",
873
"strftime",
874
"strpos",
875
"strptime",
876
"substr",
877
"sum",
878
"tan",
879
"tand",
880
"unnest",
881
"upper",
882
"var",
883
"var_samp",
884
"variance",
885
]
886
}
887
}
888
889
impl PolarsSQLFunctions {
890
fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
891
let function_name = function.name.0[0].as_ident().unwrap().value.to_lowercase();
892
Ok(match function_name.as_str() {
893
// ----
894
// Bitwise functions
895
// ----
896
"bit_and" | "bitand" => Self::BitAnd,
897
#[cfg(feature = "bitwise")]
898
"bit_count" | "bitcount" => Self::BitCount,
899
"bit_not" | "bitnot" => Self::BitNot,
900
"bit_or" | "bitor" => Self::BitOr,
901
"bit_xor" | "bitxor" | "xor" => Self::BitXor,
902
903
// ----
904
// Math functions
905
// ----
906
"abs" => Self::Abs,
907
"cbrt" => Self::Cbrt,
908
"ceil" | "ceiling" => Self::Ceil,
909
"div" => Self::Div,
910
"exp" => Self::Exp,
911
"floor" => Self::Floor,
912
"ln" => Self::Ln,
913
"log" => Self::Log,
914
"log10" => Self::Log10,
915
"log1p" => Self::Log1p,
916
"log2" => Self::Log2,
917
"mod" => Self::Mod,
918
"pi" => Self::Pi,
919
"pow" | "power" => Self::Pow,
920
"round" => Self::Round,
921
"sign" => Self::Sign,
922
"sqrt" => Self::Sqrt,
923
924
// ----
925
// Trig functions
926
// ----
927
"cos" => Self::Cos,
928
"cot" => Self::Cot,
929
"sin" => Self::Sin,
930
"tan" => Self::Tan,
931
"cosd" => Self::CosD,
932
"cotd" => Self::CotD,
933
"sind" => Self::SinD,
934
"tand" => Self::TanD,
935
"acos" => Self::Acos,
936
"asin" => Self::Asin,
937
"atan" => Self::Atan,
938
"atan2" => Self::Atan2,
939
"acosd" => Self::AcosD,
940
"asind" => Self::AsinD,
941
"atand" => Self::AtanD,
942
"atan2d" => Self::Atan2D,
943
"degrees" => Self::Degrees,
944
"radians" => Self::Radians,
945
946
// ----
947
// Conditional functions
948
// ----
949
"coalesce" => Self::Coalesce,
950
"greatest" => Self::Greatest,
951
"if" => Self::If,
952
"ifnull" => Self::IfNull,
953
"least" => Self::Least,
954
"nullif" => Self::NullIf,
955
956
// ----
957
// Date functions
958
// ----
959
"date_part" => Self::DatePart,
960
"strftime" => Self::Strftime,
961
962
// ----
963
// String functions
964
// ----
965
"bit_length" => Self::BitLength,
966
"concat" => Self::Concat,
967
"concat_ws" => Self::ConcatWS,
968
"date" => Self::Date,
969
"timestamp" | "datetime" => Self::Timestamp,
970
"ends_with" => Self::EndsWith,
971
#[cfg(feature = "nightly")]
972
"initcap" => Self::InitCap,
973
"length" | "char_length" | "character_length" => Self::Length,
974
"left" => Self::Left,
975
"lower" => Self::Lower,
976
"ltrim" => Self::LTrim,
977
"normalize" => Self::Normalize,
978
"octet_length" => Self::OctetLength,
979
"strpos" => Self::StrPos,
980
"regexp_like" => Self::RegexpLike,
981
"replace" => Self::Replace,
982
"reverse" => Self::Reverse,
983
"right" => Self::Right,
984
"rtrim" => Self::RTrim,
985
"split_part" => Self::SplitPart,
986
"starts_with" => Self::StartsWith,
987
"string_to_array" => Self::StringToArray,
988
"strptime" => Self::Strptime,
989
"substr" => Self::Substring,
990
"time" => Self::Time,
991
"upper" => Self::Upper,
992
993
// ----
994
// Aggregate functions
995
// ----
996
"avg" => Self::Avg,
997
"corr" => Self::Corr,
998
"count" => Self::Count,
999
"covar_pop" => Self::CovarPop,
1000
"covar" | "covar_samp" => Self::CovarSamp,
1001
"first" => Self::First,
1002
"last" => Self::Last,
1003
"max" => Self::Max,
1004
"median" => Self::Median,
1005
"quantile_cont" => Self::QuantileCont,
1006
"quantile_disc" => Self::QuantileDisc,
1007
"min" => Self::Min,
1008
"stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
1009
"sum" => Self::Sum,
1010
"var" | "variance" | "var_samp" => Self::Variance,
1011
1012
// ----
1013
// Array functions
1014
// ----
1015
"array_agg" => Self::ArrayAgg,
1016
"array_contains" => Self::ArrayContains,
1017
"array_get" => Self::ArrayGet,
1018
"array_length" => Self::ArrayLength,
1019
"array_lower" => Self::ArrayMin,
1020
"array_mean" => Self::ArrayMean,
1021
"array_reverse" => Self::ArrayReverse,
1022
"array_sum" => Self::ArraySum,
1023
"array_to_string" => Self::ArrayToString,
1024
"array_unique" => Self::ArrayUnique,
1025
"array_upper" => Self::ArrayMax,
1026
"unnest" => Self::Explode,
1027
1028
// ----
1029
// Window functions
1030
// ----
1031
#[cfg(feature = "rank")]
1032
"dense_rank" => Self::DenseRank,
1033
"first_value" => Self::FirstValue,
1034
"last_value" => Self::LastValue,
1035
"lag" => Self::Lag,
1036
"lead" => Self::Lead,
1037
#[cfg(feature = "rank")]
1038
"rank" => Self::Rank,
1039
"row_number" => Self::RowNumber,
1040
1041
// ----
1042
// Column selection
1043
// ----
1044
"columns" => Self::Columns,
1045
1046
other => {
1047
if ctx.function_registry.contains(other) {
1048
Self::Udf(other.to_string())
1049
} else {
1050
polars_bail!(SQLInterface: "unsupported function '{}'", other);
1051
}
1052
},
1053
})
1054
}
1055
}
1056
1057
impl SQLFunctionVisitor<'_> {
1058
pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
1059
use PolarsSQLFunctions::*;
1060
use polars_lazy::prelude::Literal;
1061
1062
let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
1063
let function = self.func;
1064
1065
// TODO: implement the following functions where possible
1066
if !function.within_group.is_empty() {
1067
polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
1068
}
1069
if function.filter.is_some() {
1070
polars_bail!(SQLInterface: "'FILTER' is not currently supported")
1071
}
1072
if function.null_treatment.is_some() {
1073
polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
1074
}
1075
1076
let log_with_base =
1077
|e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
1078
1079
match function_name {
1080
// ----
1081
// Bitwise functions
1082
// ----
1083
BitAnd => self.visit_binary::<Expr>(Expr::and),
1084
#[cfg(feature = "bitwise")]
1085
BitCount => self.visit_unary(Expr::bitwise_count_ones),
1086
BitNot => self.visit_unary(Expr::not),
1087
BitOr => self.visit_binary::<Expr>(Expr::or),
1088
BitXor => self.visit_binary::<Expr>(Expr::xor),
1089
1090
// ----
1091
// Math functions
1092
// ----
1093
Abs => self.visit_unary(Expr::abs),
1094
Cbrt => self.visit_unary(Expr::cbrt),
1095
Ceil => self.visit_unary(Expr::ceil),
1096
Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1097
Exp => self.visit_unary(Expr::exp),
1098
Floor => self.visit_unary(Expr::floor),
1099
Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1100
Log => self.visit_binary(Expr::log),
1101
Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1102
Log1p => self.visit_unary(Expr::log1p),
1103
Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1104
Pi => self.visit_nullary(Expr::pi),
1105
Mod => self.visit_binary(|e1, e2| e1 % e2),
1106
Pow => self.visit_binary::<Expr>(Expr::pow),
1107
Round => {
1108
let args = extract_args(function)?;
1109
match args.len() {
1110
1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1111
2 => self.try_visit_binary(|e, decimals| {
1112
Ok(e.round(match decimals {
1113
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1114
if n >= 0 { n as u32 } else {
1115
polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1])
1116
}
1117
},
1118
_ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1119
}, RoundMode::default()))
1120
}),
1121
_ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1122
}
1123
},
1124
Sign => self.visit_unary(Expr::sign),
1125
Sqrt => self.visit_unary(Expr::sqrt),
1126
1127
// ----
1128
// Trig functions
1129
// ----
1130
Acos => self.visit_unary(Expr::arccos),
1131
AcosD => self.visit_unary(|e| e.arccos().degrees()),
1132
Asin => self.visit_unary(Expr::arcsin),
1133
AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1134
Atan => self.visit_unary(Expr::arctan),
1135
Atan2 => self.visit_binary(Expr::arctan2),
1136
Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1137
AtanD => self.visit_unary(|e| e.arctan().degrees()),
1138
Cos => self.visit_unary(Expr::cos),
1139
CosD => self.visit_unary(|e| e.radians().cos()),
1140
Cot => self.visit_unary(Expr::cot),
1141
CotD => self.visit_unary(|e| e.radians().cot()),
1142
Degrees => self.visit_unary(Expr::degrees),
1143
Radians => self.visit_unary(Expr::radians),
1144
Sin => self.visit_unary(Expr::sin),
1145
SinD => self.visit_unary(|e| e.radians().sin()),
1146
Tan => self.visit_unary(Expr::tan),
1147
TanD => self.visit_unary(|e| e.radians().tan()),
1148
1149
// ----
1150
// Conditional functions
1151
// ----
1152
Coalesce => self.visit_variadic(coalesce),
1153
Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1154
If => {
1155
let args = extract_args(function)?;
1156
match args.len() {
1157
3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1158
Ok(when(cond).then(expr1).otherwise(expr2))
1159
}),
1160
_ => {
1161
polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1162
)
1163
},
1164
}
1165
},
1166
IfNull => {
1167
let args = extract_args(function)?;
1168
match args.len() {
1169
2 => self.visit_variadic(coalesce),
1170
_ => {
1171
polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1172
},
1173
}
1174
},
1175
Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1176
NullIf => {
1177
let args = extract_args(function)?;
1178
match args.len() {
1179
2 => self.visit_binary(|l: Expr, r: Expr| {
1180
when(l.clone().eq(r))
1181
.then(lit(LiteralValue::untyped_null()))
1182
.otherwise(l)
1183
}),
1184
_ => {
1185
polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1186
},
1187
}
1188
},
1189
1190
// ----
1191
// Date functions
1192
// ----
1193
DatePart => self.try_visit_binary(|part, e| {
1194
match part {
1195
Expr::Literal(p) if p.extract_str().is_some() => {
1196
let p = p.extract_str().unwrap();
1197
// note: 'DATE_PART' and 'EXTRACT' are minor syntactic
1198
// variations on otherwise identical functionality
1199
parse_extract_date_part(
1200
e,
1201
&DateTimeField::Custom(Ident {
1202
value: p.to_string(),
1203
quote_style: None,
1204
span: Span::empty(),
1205
}),
1206
)
1207
},
1208
_ => {
1209
polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1210
},
1211
}
1212
}),
1213
Strftime => {
1214
let args = extract_args(function)?;
1215
match args.len() {
1216
2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1217
_ => {
1218
polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1219
},
1220
}
1221
},
1222
1223
// ----
1224
// String functions
1225
// ----
1226
BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1227
Concat => {
1228
let args = extract_args(function)?;
1229
if args.is_empty() {
1230
polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1231
} else {
1232
self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1233
}
1234
},
1235
ConcatWS => {
1236
let args = extract_args(function)?;
1237
if args.len() < 2 {
1238
polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1239
} else {
1240
self.try_visit_variadic(|exprs: &[Expr]| {
1241
match &exprs[0] {
1242
Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1243
_ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1244
}
1245
})
1246
}
1247
},
1248
Date => {
1249
let args = extract_args(function)?;
1250
match args.len() {
1251
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1252
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1253
_ => {
1254
polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1255
},
1256
}
1257
},
1258
EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1259
#[cfg(feature = "nightly")]
1260
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1261
Left => self.try_visit_binary(|e, length| {
1262
Ok(match length {
1263
Expr::Literal(lv) if lv.is_null() => lit(lv),
1264
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1265
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1266
let len = if n > 0 {
1267
lit(n)
1268
} else {
1269
(e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1270
};
1271
e.str().slice(lit(0), len)
1272
},
1273
Expr::Literal(v) => {
1274
polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1275
},
1276
_ => when(length.clone().gt_eq(lit(0)))
1277
.then(e.clone().str().slice(lit(0), length.clone().abs()))
1278
.otherwise(e.clone().str().slice(
1279
lit(0),
1280
(e.str().len_chars() + length.clone()).clip_min(lit(0)),
1281
)),
1282
})
1283
}),
1284
Length => self.visit_unary(|e| e.str().len_chars()),
1285
Lower => self.visit_unary(|e| e.str().to_lowercase()),
1286
LTrim => {
1287
let args = extract_args(function)?;
1288
match args.len() {
1289
1 => self.visit_unary(|e| {
1290
e.str().strip_chars_start(lit(LiteralValue::untyped_null()))
1291
}),
1292
2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)),
1293
_ => {
1294
polars_bail!(SQLSyntax: "LTRIM expects 1-2 arguments (found {})", args.len())
1295
},
1296
}
1297
},
1298
Normalize => {
1299
let args = extract_args(function)?;
1300
match args.len() {
1301
1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1302
2 => {
1303
let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1304
value: s,
1305
quote_style: None,
1306
span: _,
1307
})) = args[1]
1308
{
1309
match s.to_uppercase().as_str() {
1310
"NFC" => UnicodeForm::NFC,
1311
"NFD" => UnicodeForm::NFD,
1312
"NFKC" => UnicodeForm::NFKC,
1313
"NFKD" => UnicodeForm::NFKD,
1314
_ => {
1315
polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1316
},
1317
}
1318
} else {
1319
polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1320
};
1321
self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1322
},
1323
_ => {
1324
polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1325
},
1326
}
1327
},
1328
OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1329
StrPos => {
1330
// note: SQL is 1-indexed; returns zero if no match found
1331
self.visit_binary(|expr, substring| {
1332
(expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1333
})
1334
},
1335
RegexpLike => {
1336
let args = extract_args(function)?;
1337
match args.len() {
1338
2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1339
3 => self.try_visit_ternary(|e, pat, flags| {
1340
Ok(e.str().contains(
1341
match (pat, flags) {
1342
(Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1343
let s = s_lv.extract_str().unwrap();
1344
let f = f_lv.extract_str().unwrap();
1345
if f.is_empty() {
1346
polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1347
};
1348
lit(format!("(?{f}){s}"))
1349
},
1350
_ => {
1351
polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1352
},
1353
},
1354
true))
1355
}),
1356
_ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1357
}
1358
},
1359
Replace => {
1360
let args = extract_args(function)?;
1361
match args.len() {
1362
3 => self
1363
.try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1364
_ => {
1365
polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1366
},
1367
}
1368
},
1369
Reverse => self.visit_unary(|e| e.str().reverse()),
1370
Right => self.try_visit_binary(|e, length| {
1371
Ok(match length {
1372
Expr::Literal(lv) if lv.is_null() => lit(lv),
1373
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1374
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1375
let n: i64 = n.try_into().unwrap();
1376
let offset = if n < 0 {
1377
lit(n.abs())
1378
} else {
1379
e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1380
};
1381
e.str().slice(offset, lit(LiteralValue::untyped_null()))
1382
},
1383
Expr::Literal(v) => {
1384
polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1385
},
1386
_ => when(length.clone().lt(lit(0)))
1387
.then(
1388
e.clone()
1389
.str()
1390
.slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1391
)
1392
.otherwise(e.clone().str().slice(
1393
e.str().len_chars().cast(DataType::Int32) - length.clone(),
1394
lit(LiteralValue::untyped_null()),
1395
)),
1396
})
1397
}),
1398
RTrim => {
1399
let args = extract_args(function)?;
1400
match args.len() {
1401
1 => self.visit_unary(|e| {
1402
e.str().strip_chars_end(lit(LiteralValue::untyped_null()))
1403
}),
1404
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
1405
_ => {
1406
polars_bail!(SQLSyntax: "RTRIM expects 1-2 arguments (found {})", args.len())
1407
},
1408
}
1409
},
1410
SplitPart => {
1411
let args = extract_args(function)?;
1412
match args.len() {
1413
3 => self.try_visit_ternary(|e, sep, idx| {
1414
let idx = adjust_one_indexed_param(idx, true);
1415
Ok(when(e.clone().is_not_null())
1416
.then(
1417
e.clone()
1418
.str()
1419
.split(sep)
1420
.list()
1421
.get(idx, true)
1422
.fill_null(lit("")),
1423
)
1424
.otherwise(e))
1425
}),
1426
_ => {
1427
polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1428
},
1429
}
1430
},
1431
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1432
StringToArray => {
1433
let args = extract_args(function)?;
1434
match args.len() {
1435
2 => self.visit_binary(|e, sep| e.str().split(sep)),
1436
_ => {
1437
polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1438
},
1439
}
1440
},
1441
Strptime => {
1442
let args = extract_args(function)?;
1443
match args.len() {
1444
2 => self.visit_binary(|e, fmt: String| {
1445
e.str().strptime(
1446
DataType::Datetime(TimeUnit::Microseconds, None),
1447
StrptimeOptions {
1448
format: Some(fmt.into()),
1449
..Default::default()
1450
},
1451
lit("latest"),
1452
)
1453
}),
1454
_ => {
1455
polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1456
},
1457
}
1458
},
1459
Time => {
1460
let args = extract_args(function)?;
1461
match args.len() {
1462
1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1463
2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1464
_ => {
1465
polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1466
},
1467
}
1468
},
1469
Timestamp => {
1470
let args = extract_args(function)?;
1471
match args.len() {
1472
1 => self.visit_unary(|e| {
1473
e.str()
1474
.to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1475
}),
1476
2 => self
1477
.visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1478
_ => {
1479
polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1480
},
1481
}
1482
},
1483
Substring => {
1484
let args = extract_args(function)?;
1485
match args.len() {
1486
// note: SQL is 1-indexed, hence the need for adjustments
1487
2 => self.try_visit_binary(|e, start| {
1488
Ok(match start {
1489
Expr::Literal(lv) if lv.is_null() => lit(lv),
1490
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1491
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1492
Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1493
_ => start.clone() + lit(1),
1494
})
1495
}),
1496
3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1497
Ok(match (start.clone(), length.clone()) {
1498
(Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1499
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1500
polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1501
},
1502
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1503
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1504
e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1505
},
1506
(Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1507
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1508
polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1509
},
1510
_ => {
1511
let adjusted_start = start - lit(1);
1512
when(adjusted_start.clone().lt(lit(0)))
1513
.then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1514
.otherwise(e.str().slice(adjusted_start, length))
1515
}
1516
})
1517
}),
1518
_ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1519
}
1520
},
1521
Upper => self.visit_unary(|e| e.str().to_uppercase()),
1522
1523
// ----
1524
// Aggregate functions
1525
// ----
1526
Avg => self.visit_unary(Expr::mean),
1527
Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1528
Count => self.visit_count(),
1529
CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1530
CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1531
First => self.visit_unary(Expr::first),
1532
Last => self.visit_unary(Expr::last),
1533
Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1534
Median => self.visit_unary(Expr::median),
1535
QuantileCont => {
1536
let args = extract_args(function)?;
1537
match args.len() {
1538
2 => self.try_visit_binary(|e, q| {
1539
let value = match q {
1540
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1541
if (0.0..=1.0).contains(&f) {
1542
Expr::from(f)
1543
} else {
1544
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1545
}
1546
},
1547
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1548
if (0..=1).contains(&n) {
1549
Expr::from(n as f64)
1550
} else {
1551
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1552
}
1553
},
1554
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
1555
};
1556
Ok(e.quantile(value, QuantileMethod::Linear))
1557
}),
1558
_ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
1559
}
1560
},
1561
QuantileDisc => {
1562
let args = extract_args(function)?;
1563
match args.len() {
1564
2 => self.try_visit_binary(|e, q| {
1565
let value = match q {
1566
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1567
if (0.0..=1.0).contains(&f) {
1568
Expr::from(f)
1569
} else {
1570
polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1571
}
1572
},
1573
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1574
if (0..=1).contains(&n) {
1575
Expr::from(n as f64)
1576
} else {
1577
polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1578
}
1579
},
1580
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1])
1581
};
1582
Ok(e.quantile(value, QuantileMethod::Equiprobable))
1583
}),
1584
_ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()),
1585
}
1586
},
1587
Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1588
StdDev => self.visit_unary(|e| e.std(1)),
1589
Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1590
Variance => self.visit_unary(|e| e.var(1)),
1591
1592
// ----
1593
// Array functions
1594
// ----
1595
ArrayAgg => self.visit_arr_agg(),
1596
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1597
ArrayGet => {
1598
// note: SQL is 1-indexed, not 0-indexed
1599
self.visit_binary(|e, idx: Expr| {
1600
let idx = adjust_one_indexed_param(idx, true);
1601
e.list().get(idx, true)
1602
})
1603
},
1604
ArrayLength => self.visit_unary(|e| e.list().len()),
1605
ArrayMax => self.visit_unary(|e| e.list().max()),
1606
ArrayMean => self.visit_unary(|e| e.list().mean()),
1607
ArrayMin => self.visit_unary(|e| e.list().min()),
1608
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
1609
ArraySum => self.visit_unary(|e| e.list().sum()),
1610
ArrayToString => self.visit_arr_to_string(),
1611
ArrayUnique => self.visit_unary(|e| e.list().unique_stable()),
1612
Explode => self.visit_unary(|e| {
1613
e.explode(ExplodeOptions {
1614
empty_as_null: true,
1615
keep_nulls: true,
1616
})
1617
}),
1618
1619
// ----
1620
// Column selection
1621
// ----
1622
Columns => {
1623
let active_schema = self.active_schema;
1624
self.try_visit_unary(|e: Expr| match e {
1625
Expr::Literal(lv) if lv.extract_str().is_some() => {
1626
let pat = lv.extract_str().unwrap();
1627
if pat == "*" {
1628
polars_bail!(
1629
SQLSyntax: "COLUMNS('*') is not a valid regex; \
1630
did you mean COLUMNS(*)?"
1631
)
1632
};
1633
let pat = match pat {
1634
_ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1635
_ if pat.starts_with('^') => format!("{pat}.*$"),
1636
_ if pat.ends_with('$') => format!("^.*{pat}"),
1637
_ => format!("^.*{pat}.*$"),
1638
};
1639
if let Some(active_schema) = &active_schema {
1640
let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1641
let col_names = active_schema
1642
.iter_names()
1643
.filter(|name| rx.is_match(name))
1644
.cloned()
1645
.collect::<Vec<_>>();
1646
1647
Ok(if col_names.len() == 1 {
1648
col(col_names.into_iter().next().unwrap())
1649
} else {
1650
cols(col_names).as_expr()
1651
})
1652
} else {
1653
Ok(col(pat.as_str()))
1654
}
1655
},
1656
Expr::Selector(s) => Ok(s.as_expr()),
1657
_ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1658
})
1659
},
1660
1661
// ----
1662
// Window functions
1663
// ----
1664
FirstValue => self.visit_unary(Expr::first),
1665
LastValue => {
1666
// With the default window frame (ROWS UNBOUNDED PRECEDING TO CURRENT ROW),
1667
// LAST_VALUE returns the last value from the start of the partition up
1668
// to the current row - which is simply the current row's value.
1669
let args = extract_args(function)?;
1670
match args.as_slice() {
1671
[FunctionArgExpr::Expr(sql_expr)] => {
1672
parse_sql_expr(sql_expr, self.ctx, self.active_schema)
1673
},
1674
_ => polars_bail!(
1675
SQLSyntax: "LAST_VALUE expects exactly 1 argument (found {})",
1676
args.len()
1677
),
1678
}
1679
},
1680
Lag => self.visit_window_offset_function(1),
1681
Lead => self.visit_window_offset_function(-1),
1682
#[cfg(feature = "rank")]
1683
Rank | DenseRank => {
1684
let (func_name, rank_method) = match function_name {
1685
Rank => ("RANK", RankMethod::Min),
1686
DenseRank => ("DENSE_RANK", RankMethod::Dense),
1687
_ => unreachable!(),
1688
};
1689
let args = extract_args(function)?;
1690
if !args.is_empty() {
1691
polars_bail!(SQLSyntax: "{} expects 0 arguments (found {})", func_name, args.len());
1692
}
1693
let window_spec = match &self.func.over {
1694
Some(WindowType::WindowSpec(spec)) if !spec.order_by.is_empty() => spec,
1695
_ => {
1696
polars_bail!(SQLSyntax: "{} requires an OVER clause with ORDER BY", func_name)
1697
},
1698
};
1699
let (order_exprs, all_desc) =
1700
self.parse_order_by_in_window(&window_spec.order_by)?;
1701
let rank_expr = if order_exprs.len() == 1 {
1702
order_exprs[0].clone().rank(
1703
RankOptions {
1704
method: rank_method,
1705
descending: all_desc,
1706
},
1707
None,
1708
)
1709
} else {
1710
as_struct(order_exprs).rank(
1711
RankOptions {
1712
method: rank_method,
1713
descending: all_desc,
1714
},
1715
None,
1716
)
1717
};
1718
self.apply_window_spec(rank_expr, &self.func.over)
1719
},
1720
RowNumber => {
1721
let args = extract_args(function)?;
1722
if !args.is_empty() {
1723
polars_bail!(SQLSyntax: "ROW_NUMBER expects 0 arguments (found {})", args.len());
1724
}
1725
// note: SQL is 1-indexed
1726
let row_num_expr = int_range(lit(0i64), len(), 1, DataType::UInt32) + lit(1u32);
1727
self.apply_window_spec(row_num_expr, &self.func.over)
1728
},
1729
1730
// ----
1731
// User-defined
1732
// ----
1733
Udf(func_name) => self.visit_udf(&func_name),
1734
}
1735
}
1736
1737
fn visit_window_offset_function(&mut self, offset_multiplier: i64) -> PolarsResult<Expr> {
1738
// LAG/LEAD require an OVER clause
1739
if self.func.over.is_none() {
1740
polars_bail!(SQLSyntax: "{} requires an OVER clause", self.func.name);
1741
}
1742
1743
// LAG/LEAD require ORDER BY in the OVER clause
1744
let window_type = self.func.over.as_ref().unwrap();
1745
let window_spec = self.resolve_window_spec(window_type)?;
1746
if window_spec.order_by.is_empty() {
1747
polars_bail!(SQLSyntax: "{} requires an ORDER BY in the OVER clause", self.func.name);
1748
}
1749
1750
let args = extract_args(self.func)?;
1751
1752
match args.as_slice() {
1753
[FunctionArgExpr::Expr(sql_expr)] => {
1754
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1755
Ok(expr.shift(offset_multiplier.into()))
1756
},
1757
[FunctionArgExpr::Expr(sql_expr), FunctionArgExpr::Expr(offset_expr)] => {
1758
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1759
let offset = parse_sql_expr(offset_expr, self.ctx, self.active_schema)?;
1760
if let Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) = offset {
1761
if n <= 0 {
1762
polars_bail!(SQLSyntax: "offset must be positive (found {})", n)
1763
}
1764
Ok(expr.shift((offset_multiplier * n as i64).into()))
1765
} else {
1766
polars_bail!(SQLSyntax: "offset must be an integer (found {:?})", offset)
1767
}
1768
},
1769
_ => polars_bail!(SQLSyntax: "{} expects 1 or 2 arguments (found {})", self.func.name, args.len()),
1770
}.and_then(|e| self.apply_window_spec(e, &self.func.over))
1771
}
1772
1773
fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1774
let args = extract_args(self.func)?
1775
.into_iter()
1776
.map(|arg| {
1777
if let FunctionArgExpr::Expr(e) = arg {
1778
parse_sql_expr(e, self.ctx, self.active_schema)
1779
} else {
1780
polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1781
}
1782
})
1783
.collect::<PolarsResult<Vec<_>>>()?;
1784
1785
Ok(self
1786
.ctx
1787
.function_registry
1788
.get_udf(func_name)?
1789
.ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1790
.call(args))
1791
}
1792
1793
/// Validate window frame specifications.
1794
///
1795
/// Polars only supports ROWS frame semantics, and does
1796
/// not currently support customising the window.
1797
///
1798
/// **Supported Frame Spec**
1799
/// - `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
1800
///
1801
/// **Unsupported Frame Spec**
1802
/// - `RANGE ...` (peer group semantics not implemented)
1803
/// - `GROUPS ...` (peer group semantics not implemented)
1804
/// - `ROWS` with other bounds (e.g., `<n> PRECEDING`, `FOLLOWING`, etc)
1805
fn validate_window_frame(&self, window_frame: &Option<WindowFrame>) -> PolarsResult<()> {
1806
if let Some(frame) = window_frame {
1807
match frame.units {
1808
WindowFrameUnits::Range => {
1809
polars_bail!(
1810
SQLInterface:
1811
"RANGE-based window frames are not supported"
1812
);
1813
},
1814
WindowFrameUnits::Groups => {
1815
polars_bail!(
1816
SQLInterface:
1817
"GROUPS-based window frames are not supported"
1818
);
1819
},
1820
WindowFrameUnits::Rows => {
1821
if !matches!(
1822
(&frame.start_bound, &frame.end_bound),
1823
(
1824
WindowFrameBound::Preceding(None), // UNBOUNDED PRECEDING
1825
None | Some(WindowFrameBound::CurrentRow) // CURRENT ROW
1826
)
1827
) {
1828
polars_bail!(
1829
SQLInterface:
1830
"only 'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW' is currently supported; found 'ROWS BETWEEN {} AND {}'",
1831
frame.start_bound,
1832
frame.end_bound.as_ref().map_or("CURRENT ROW", |b| {
1833
match b {
1834
WindowFrameBound::CurrentRow => "CURRENT ROW",
1835
WindowFrameBound::Preceding(_) => "N PRECEDING",
1836
WindowFrameBound::Following(_) => "N FOLLOWING",
1837
}
1838
})
1839
);
1840
}
1841
},
1842
}
1843
}
1844
Ok(())
1845
}
1846
1847
/// Window specs that map to cumulative functions.
1848
///
1849
/// Converts SQL window functions with ORDER BY to compatible cumulative ops:
1850
/// - `SUM(a) OVER (ORDER BY b)` → `a.cum_sum().over(order_by=b)`
1851
/// - `MAX(a) OVER (ORDER BY b)` → `a.cum_max().over(order_by=b)`
1852
/// - `MIN(a) OVER (ORDER BY b)` → `a.cum_min().over(order_by=b)`
1853
///
1854
/// ROWS vs RANGE Semantics (show default behaviour if no frame spec):
1855
///
1856
/// **Polars (ROWS)**
1857
/// Each row gets its own cumulative value row-by-row.
1858
/// ```text
1859
/// Data: [(A,X,10), (A,X,15), (A,Y,20)]
1860
/// Query: SUM(value) OVER (ORDER BY category, subcategory)
1861
/// Result: [10, 25, 45] ← row-by-row cumulative
1862
/// ```
1863
///
1864
/// **SQL (RANGE)**
1865
/// Rows with identical ORDER BY values (peers) get the same result.
1866
/// ```text
1867
/// Same data, query with RANGE (eg: using a relational DB):
1868
/// Result: [25, 25, 45] ← both (A,X) rows get 25
1869
/// ```
1870
fn apply_cumulative_window(
1871
&mut self,
1872
f: impl Fn(Expr) -> Expr,
1873
cumulative_fn: impl Fn(Expr, bool) -> Expr,
1874
WindowSpec {
1875
partition_by,
1876
order_by,
1877
window_frame,
1878
..
1879
}: &WindowSpec,
1880
) -> PolarsResult<Expr> {
1881
self.validate_window_frame(window_frame)?;
1882
1883
if !order_by.is_empty() {
1884
// Extract ORDER BY exprs and sort direction
1885
let (order_by_exprs, all_desc) = self.parse_order_by_in_window(order_by)?;
1886
1887
// Get the base expr/column
1888
let args = extract_args(self.func)?;
1889
let base_expr = match args.as_slice() {
1890
[FunctionArgExpr::Expr(sql_expr)] => {
1891
parse_sql_expr(sql_expr, self.ctx, self.active_schema)?
1892
},
1893
_ => return self.not_supported_error(),
1894
};
1895
let partition_by_exprs = if partition_by.is_empty() {
1896
None
1897
} else {
1898
Some(
1899
partition_by
1900
.iter()
1901
.map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1902
.collect::<PolarsResult<Vec<_>>>()?,
1903
)
1904
};
1905
1906
// Apply cumulative function and wrap with window spec
1907
let cumulative_expr = cumulative_fn(base_expr, false);
1908
let sort_opts = SortOptions::default().with_order_descending(all_desc);
1909
cumulative_expr.over_with_options(
1910
partition_by_exprs,
1911
Some((order_by_exprs, sort_opts)),
1912
Default::default(),
1913
)
1914
} else {
1915
self.visit_unary(f)
1916
}
1917
}
1918
1919
/// Single-argument function with an infallible builder `f`.
///
/// Lifts `f` into a fallible builder and delegates to `try_visit_unary`.
fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
    let fallible = move |arg: Expr| -> PolarsResult<Expr> { Ok(f(arg)) };
    self.try_visit_unary(fallible)
}
1922
1923
fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1924
let args = extract_args(self.func)?;
1925
match args.as_slice() {
1926
[FunctionArgExpr::Expr(sql_expr)] => {
1927
f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?)
1928
},
1929
[FunctionArgExpr::Wildcard] => f(parse_sql_expr(
1930
&SQLExpr::Wildcard(AttachedToken::empty()),
1931
self.ctx,
1932
self.active_schema,
1933
)?),
1934
_ => self.not_supported_error(),
1935
}
1936
.and_then(|e| self.apply_window_spec(e, &self.func.over))
1937
}
1938
1939
/// Resolve a WindowType to a concrete WindowSpec (handles named window references)
1940
fn resolve_window_spec(&self, window_type: &WindowType) -> PolarsResult<WindowSpec> {
1941
match window_type {
1942
WindowType::WindowSpec(spec) => Ok(spec.clone()),
1943
WindowType::NamedWindow(name) => self
1944
.ctx
1945
.named_windows
1946
.get(&name.value)
1947
.cloned()
1948
.ok_or_else(|| {
1949
polars_err!(
1950
SQLInterface:
1951
"named window '{}' was not found",
1952
name.value
1953
)
1954
}),
1955
}
1956
}
1957
1958
/// Some functions have cumulative equivalents that can be applied to window specs
1959
/// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false)
1960
fn visit_unary_with_opt_cumulative(
1961
&mut self,
1962
f: impl Fn(Expr) -> Expr,
1963
cumulative_fn: impl Fn(Expr, bool) -> Expr,
1964
) -> PolarsResult<Expr> {
1965
match self.func.over.as_ref() {
1966
Some(window_type) => {
1967
let spec = self.resolve_window_spec(window_type)?;
1968
self.apply_cumulative_window(f, cumulative_fn, &spec)
1969
},
1970
None => self.visit_unary(f),
1971
}
1972
}
1973
1974
fn visit_binary<Arg: FromSQLExpr>(
1975
&mut self,
1976
f: impl Fn(Expr, Arg) -> Expr,
1977
) -> PolarsResult<Expr> {
1978
self.try_visit_binary(|e, a| Ok(f(e, a)))
1979
}
1980
1981
fn try_visit_binary<Arg: FromSQLExpr>(
1982
&mut self,
1983
f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
1984
) -> PolarsResult<Expr> {
1985
let args = extract_args(self.func)?;
1986
match args.as_slice() {
1987
[
1988
FunctionArgExpr::Expr(sql_expr1),
1989
FunctionArgExpr::Expr(sql_expr2),
1990
] => {
1991
let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1992
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1993
f(expr1, expr2)
1994
},
1995
_ => self.not_supported_error(),
1996
}
1997
}
1998
1999
/// Variadic function with an infallible builder `f`.
///
/// Lifts `f` into a fallible builder and delegates to `try_visit_variadic`.
fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
    let fallible = move |exprs: &[Expr]| -> PolarsResult<Expr> { Ok(f(exprs)) };
    self.try_visit_variadic(fallible)
}
2002
2003
fn try_visit_variadic(
2004
&mut self,
2005
f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
2006
) -> PolarsResult<Expr> {
2007
let args = extract_args(self.func)?;
2008
let mut expr_args = vec![];
2009
for arg in args {
2010
if let FunctionArgExpr::Expr(sql_expr) = arg {
2011
expr_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
2012
} else {
2013
return self.not_supported_error();
2014
};
2015
}
2016
f(&expr_args)
2017
}
2018
2019
fn try_visit_ternary<Arg: FromSQLExpr>(
2020
&mut self,
2021
f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
2022
) -> PolarsResult<Expr> {
2023
let args = extract_args(self.func)?;
2024
match args.as_slice() {
2025
[
2026
FunctionArgExpr::Expr(sql_expr1),
2027
FunctionArgExpr::Expr(sql_expr2),
2028
FunctionArgExpr::Expr(sql_expr3),
2029
] => {
2030
let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
2031
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
2032
let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
2033
f(expr1, expr2, expr3)
2034
},
2035
_ => self.not_supported_error(),
2036
}
2037
}
2038
2039
fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
2040
let args = extract_args(self.func)?;
2041
if !args.is_empty() {
2042
return self.not_supported_error();
2043
}
2044
Ok(f())
2045
}
2046
2047
fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
2048
let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
2049
match args.as_slice() {
2050
[FunctionArgExpr::Expr(sql_expr)] => {
2051
let mut base = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
2052
let mut order_by_clause = None;
2053
let mut limit_clause = None;
2054
for clause in &clauses {
2055
match clause {
2056
FunctionArgumentClause::OrderBy(order_exprs) => {
2057
order_by_clause = Some(order_exprs.as_slice());
2058
},
2059
FunctionArgumentClause::Limit(limit_expr) => {
2060
limit_clause = Some(limit_expr);
2061
},
2062
_ => {},
2063
}
2064
}
2065
if !is_distinct {
2066
// No DISTINCT: apply ORDER BY normally
2067
if let Some(order_by) = order_by_clause {
2068
base = self.apply_order_by(base, order_by)?;
2069
}
2070
} else {
2071
// DISTINCT: apply unique, then sort the result
2072
base = base.unique_stable();
2073
if let Some(order_by) = order_by_clause {
2074
base = self.apply_order_by_to_distinct_array(base, order_by, sql_expr)?;
2075
}
2076
}
2077
if let Some(limit_expr) = limit_clause {
2078
let limit = parse_sql_expr(limit_expr, self.ctx, self.active_schema)?;
2079
match limit {
2080
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n >= 0 => {
2081
base = base.head(Some(n as usize))
2082
},
2083
_ => {
2084
polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer")
2085
},
2086
};
2087
}
2088
Ok(base.implode())
2089
},
2090
_ => {
2091
polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
2092
},
2093
}
2094
}
2095
2096
fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
2097
let args = extract_args(self.func)?;
2098
match args.len() {
2099
2 => self.try_visit_binary(|e, sep| {
2100
Ok(e.cast(DataType::List(Box::from(DataType::String)))
2101
.list()
2102
.join(sep, true))
2103
}),
2104
#[cfg(feature = "list_eval")]
2105
3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
2106
Expr::Literal(lv) if lv.extract_str().is_some() => {
2107
Ok(if lv.extract_str().unwrap().is_empty() {
2108
e.cast(DataType::List(Box::from(DataType::String)))
2109
.list()
2110
.join(sep, true)
2111
} else {
2112
e.cast(DataType::List(Box::from(DataType::String)))
2113
.list()
2114
.eval(element().fill_null(lit(lv.extract_str().unwrap())))
2115
.list()
2116
.join(sep, false)
2117
})
2118
},
2119
_ => {
2120
polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
2121
},
2122
}),
2123
_ => {
2124
polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
2125
},
2126
}
2127
}
2128
2129
fn visit_count(&mut self) -> PolarsResult<Expr> {
2130
let (args, is_distinct) = extract_args_distinct(self.func)?;
2131
2132
// Window function with an ORDER BY clause?
2133
let has_order_by = match &self.func.over {
2134
Some(WindowType::WindowSpec(spec)) => !spec.order_by.is_empty(),
2135
_ => false,
2136
};
2137
if has_order_by && !is_distinct {
2138
if let Some(WindowType::WindowSpec(spec)) = &self.func.over {
2139
self.validate_window_frame(&spec.window_frame)?;
2140
2141
match args.as_slice() {
2142
[FunctionArgExpr::Wildcard] | [] => {
2143
// COUNT(*) with ORDER BY -> map to `int_range`
2144
let (order_by_exprs, all_desc) =
2145
self.parse_order_by_in_window(&spec.order_by)?;
2146
let partition_by_exprs = if spec.partition_by.is_empty() {
2147
None
2148
} else {
2149
Some(
2150
spec.partition_by
2151
.iter()
2152
.map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2153
.collect::<PolarsResult<Vec<_>>>()?,
2154
)
2155
};
2156
let sort_opts = SortOptions::default().with_order_descending(all_desc);
2157
let row_number = int_range(lit(0), len(), 1, DataType::Int64).add(lit(1)); // SQL is 1-indexed
2158
2159
return row_number.over_with_options(
2160
partition_by_exprs,
2161
Some((order_by_exprs, sort_opts)),
2162
Default::default(),
2163
);
2164
},
2165
[FunctionArgExpr::Expr(_)] => {
2166
// COUNT(column) with ORDER BY -> use cum_count
2167
return self.visit_unary_with_opt_cumulative(
2168
|e| e.count(),
2169
|e, reverse| e.cum_count(reverse),
2170
);
2171
},
2172
_ => {},
2173
}
2174
}
2175
}
2176
let count_expr = match (is_distinct, args.as_slice()) {
2177
// COUNT(*), COUNT()
2178
(false, [FunctionArgExpr::Wildcard] | []) => len(),
2179
// COUNT(col)
2180
(false, [FunctionArgExpr::Expr(sql_expr)]) => {
2181
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
2182
expr.count()
2183
},
2184
// COUNT(DISTINCT col)
2185
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
2186
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
2187
expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
2188
},
2189
_ => self.not_supported_error()?,
2190
};
2191
self.apply_window_spec(count_expr, &self.func.over)
2192
}
2193
2194
fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
2195
let mut by = Vec::with_capacity(order_by.len());
2196
let mut descending = Vec::with_capacity(order_by.len());
2197
let mut nulls_last = Vec::with_capacity(order_by.len());
2198
2199
for ob in order_by {
2200
// Note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise
2201
// https://www.postgresql.org/docs/current/queries-order.html
2202
let desc_order = !ob.options.asc.unwrap_or(true);
2203
by.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?);
2204
nulls_last.push(!ob.options.nulls_first.unwrap_or(desc_order));
2205
descending.push(desc_order);
2206
}
2207
Ok(expr.sort_by(
2208
by,
2209
SortMultipleOptions::default()
2210
.with_order_descending_multi(descending)
2211
.with_nulls_last_multi(nulls_last),
2212
))
2213
}
2214
2215
fn apply_order_by_to_distinct_array(
2216
&mut self,
2217
expr: Expr,
2218
order_by: &[OrderByExpr],
2219
base_sql_expr: &SQLExpr,
2220
) -> PolarsResult<Expr> {
2221
// If ORDER BY references the base expression, use .sort() directly
2222
if order_by.len() == 1 && order_by[0].expr == *base_sql_expr {
2223
let desc_order = !order_by[0].options.asc.unwrap_or(true);
2224
let nulls_last = !order_by[0].options.nulls_first.unwrap_or(desc_order);
2225
return Ok(expr.sort(
2226
SortOptions::default()
2227
.with_order_descending(desc_order)
2228
.with_nulls_last(nulls_last)
2229
.with_maintain_order(true),
2230
));
2231
}
2232
// Otherwise, fall back to `sort_by` (may need to handle further edge-cases later)
2233
self.apply_order_by(expr, order_by)
2234
}
2235
2236
/// Parse ORDER BY (in OVER clause), validating uniform direction.
2237
fn parse_order_by_in_window(
2238
&mut self,
2239
order_by: &[OrderByExpr],
2240
) -> PolarsResult<(Vec<Expr>, bool)> {
2241
if order_by.is_empty() {
2242
return Ok((Vec::new(), false));
2243
}
2244
// Parse expressions and validate uniform direction
2245
let all_ascending = order_by[0].options.asc.unwrap_or(true);
2246
let mut exprs = Vec::with_capacity(order_by.len());
2247
for o in order_by {
2248
if all_ascending != o.options.asc.unwrap_or(true) {
2249
// TODO: mixed sort directions are not currently supported; we
2250
// need to enhance `over_with_options` to take SortMultipleOptions
2251
polars_bail!(
2252
SQLSyntax:
2253
"OVER does not (yet) support mixed asc/desc directions for ORDER BY"
2254
)
2255
}
2256
let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
2257
exprs.push(expr);
2258
}
2259
Ok((exprs, !all_ascending))
2260
}
2261
2262
fn apply_window_spec(
2263
&mut self,
2264
expr: Expr,
2265
window_type: &Option<WindowType>,
2266
) -> PolarsResult<Expr> {
2267
let Some(window_type) = window_type else {
2268
return Ok(expr);
2269
};
2270
let window_spec = self.resolve_window_spec(window_type)?;
2271
self.validate_window_frame(&window_spec.window_frame)?;
2272
2273
let partition_by = if window_spec.partition_by.is_empty() {
2274
None
2275
} else {
2276
Some(
2277
window_spec
2278
.partition_by
2279
.iter()
2280
.map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2281
.collect::<PolarsResult<Vec<_>>>()?,
2282
)
2283
};
2284
let order_by = if window_spec.order_by.is_empty() {
2285
None
2286
} else {
2287
let (order_exprs, all_desc) = self.parse_order_by_in_window(&window_spec.order_by)?;
2288
let sort_opts = SortOptions::default().with_order_descending(all_desc);
2289
Some((order_exprs, sort_opts))
2290
};
2291
2292
// Apply window spec
2293
Ok(match (partition_by, order_by) {
2294
(None, None) => expr,
2295
(Some(part), None) => expr.over(part),
2296
(part, Some(order)) => expr.over_with_options(part, Some(order), Default::default())?,
2297
})
2298
}
2299
2300
fn not_supported_error(&self) -> PolarsResult<Expr> {
2301
polars_bail!(
2302
SQLInterface:
2303
"no function matches the given name and arguments: `{}`",
2304
self.func.to_string()
2305
);
2306
}
2307
}
2308
2309
/// Return the function's argument expressions, rejecting `DISTINCT` and any
/// trailing argument clauses.
fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
    _extract_func_args(func, false, false).map(|(args, _, _)| args)
}
2313
2314
/// Return the function's argument expressions plus whether `DISTINCT` was
/// given. Trailing argument clauses are still rejected.
fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
    _extract_func_args(func, true, false).map(|(args, is_distinct, _)| (args, is_distinct))
}
2318
2319
fn extract_args_and_clauses(
2320
func: &SQLFunction,
2321
) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2322
_extract_func_args(func, true, true)
2323
}
2324
2325
fn _extract_func_args(
2326
func: &SQLFunction,
2327
get_distinct: bool,
2328
get_clauses: bool,
2329
) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2330
match &func.args {
2331
FunctionArguments::List(FunctionArgumentList {
2332
args,
2333
duplicate_treatment,
2334
clauses,
2335
}) => {
2336
let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
2337
if !(get_clauses || get_distinct) && is_distinct {
2338
polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
2339
} else if !get_clauses && !clauses.is_empty() {
2340
polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
2341
} else {
2342
let unpacked_args = args
2343
.iter()
2344
.map(|arg| match arg {
2345
FunctionArg::Named { arg, .. } => arg,
2346
FunctionArg::ExprNamed { arg, .. } => arg,
2347
FunctionArg::Unnamed(arg) => arg,
2348
})
2349
.collect();
2350
Ok((unpacked_args, is_distinct, clauses.clone()))
2351
}
2352
},
2353
FunctionArguments::Subquery { .. } => {
2354
Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
2355
},
2356
FunctionArguments::None => Ok((vec![], false, vec![])),
2357
}
2358
}
2359
2360
pub(crate) trait FromSQLExpr {
2361
fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2362
where
2363
Self: Sized;
2364
}
2365
2366
impl FromSQLExpr for f64 {
2367
fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2368
where
2369
Self: Sized,
2370
{
2371
match expr {
2372
SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2373
SQLValue::Number(s, _) => s
2374
.parse()
2375
.map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
2376
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2377
},
2378
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2379
}
2380
}
2381
}
2382
2383
impl FromSQLExpr for bool {
2384
fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2385
where
2386
Self: Sized,
2387
{
2388
match expr {
2389
SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2390
SQLValue::Boolean(v) => Ok(*v),
2391
_ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
2392
},
2393
_ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2394
}
2395
}
2396
}
2397
2398
impl FromSQLExpr for String {
2399
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2400
where
2401
Self: Sized,
2402
{
2403
match expr {
2404
SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2405
SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2406
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2407
},
2408
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2409
}
2410
}
2411
}
2412
2413
impl FromSQLExpr for StrptimeOptions {
2414
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2415
where
2416
Self: Sized,
2417
{
2418
match expr {
2419
SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2420
SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2421
format: Some(PlSmallStr::from_str(s)),
2422
..StrptimeOptions::default()
2423
}),
2424
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2425
},
2426
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2427
}
2428
}
2429
}
2430
2431
impl FromSQLExpr for Expr {
2432
fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2433
where
2434
Self: Sized,
2435
{
2436
parse_sql_expr(expr, ctx, None)
2437
}
2438
}
2439
2440