Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/src/functions.rs
6940 views
1
use std::ops::Sub;
2
3
use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
4
use polars_core::prelude::{
5
DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail, polars_err,
6
};
7
use polars_lazy::dsl::Expr;
8
use polars_ops::chunked_array::UnicodeForm;
9
use polars_ops::series::RoundMode;
10
use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when};
11
use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
12
use polars_plan::prelude::{StrptimeOptions, col, cols, lit};
13
use polars_utils::pl_str::PlSmallStr;
14
use sqlparser::ast::helpers::attached_token::AttachedToken;
15
use sqlparser::ast::{
16
DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
17
FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
18
OrderByExpr, Value as SQLValue, WindowSpec, WindowType,
19
};
20
use sqlparser::tokenizer::Span;
21
22
use crate::SQLContext;
23
use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
24
25
pub(crate) struct SQLFunctionVisitor<'a> {
26
pub(crate) func: &'a SQLFunction,
27
pub(crate) ctx: &'a mut SQLContext,
28
pub(crate) active_schema: Option<&'a Schema>,
29
}
30
31
/// SQL functions that are supported by Polars.
pub(crate) enum PolarsSQLFunctions {
    // ----
    // Bitwise functions
    // ----
    /// SQL 'bit_and' function.
    /// Returns the bitwise AND of the input expressions.
    /// ```sql
    /// SELECT BIT_AND(column_1, column_2) FROM df;
    /// ```
    BitAnd,
    /// SQL 'bit_count' function.
    /// Returns the number of set bits in the input expression.
    /// ```sql
    /// SELECT BIT_COUNT(column_1) FROM df;
    /// ```
    #[cfg(feature = "bitwise")]
    BitCount,
    /// SQL 'bit_or' function.
    /// Returns the bitwise OR of the input expressions.
    /// ```sql
    /// SELECT BIT_OR(column_1, column_2) FROM df;
    /// ```
    BitOr,
    /// SQL 'bit_xor' function.
    /// Returns the bitwise XOR of the input expressions.
    /// ```sql
    /// SELECT BIT_XOR(column_1, column_2) FROM df;
    /// ```
    BitXor,

    // ----
    // Math functions
    // ----
    /// SQL 'abs' function.
    /// Returns the absolute value of the input expression.
    /// ```sql
    /// SELECT ABS(column_1) FROM df;
    /// ```
    Abs,
    /// SQL 'ceil' function.
    /// Returns the smallest integer value greater than or equal to the input
    /// (rounds towards positive infinity).
    /// ```sql
    /// SELECT CEIL(column_1) FROM df;
    /// ```
    Ceil,
    /// SQL 'div' function.
    /// Returns the integer quotient of the division (floor division, cast to Int64).
    /// ```sql
    /// SELECT DIV(column_1, 2) FROM df;
    /// ```
    Div,
    /// SQL 'exp' function.
    /// Computes the exponential of the given value.
    /// ```sql
    /// SELECT EXP(column_1) FROM df;
    /// ```
    Exp,
    /// SQL 'floor' function.
    /// Returns the largest integer value less than or equal to the input
    /// (rounds towards negative infinity).
    /// ```sql
    /// SELECT FLOOR(column_1) FROM df;
    /// ```
    Floor,
    /// SQL 'pi' function.
    /// Returns a (very good) approximation of 𝜋.
    /// ```sql
    /// SELECT PI() FROM df;
    /// ```
    Pi,
    /// SQL 'ln' function.
    /// Computes the natural logarithm of the given value.
    /// ```sql
    /// SELECT LN(column_1) FROM df;
    /// ```
    Ln,
    /// SQL 'log2' function.
    /// Computes the logarithm of the given value in base 2.
    /// ```sql
    /// SELECT LOG2(column_1) FROM df;
    /// ```
    Log2,
    /// SQL 'log10' function.
    /// Computes the logarithm of the given value in base 10.
    /// ```sql
    /// SELECT LOG10(column_1) FROM df;
    /// ```
    Log10,
    /// SQL 'log' function.
    /// Computes the logarithm of the given value in the given `base`.
    /// ```sql
    /// SELECT LOG(column_1, 10) FROM df;
    /// ```
    Log,
    /// SQL 'log1p' function.
    /// Computes the natural logarithm of "given value plus one".
    /// ```sql
    /// SELECT LOG1P(column_1) FROM df;
    /// ```
    Log1p,
    /// SQL 'pow' function.
    /// Returns the value to the power of the given exponent.
    /// ```sql
    /// SELECT POW(column_1, 2) FROM df;
    /// ```
    Pow,
    /// SQL 'mod' function.
    /// Returns the remainder of a numeric expression divided by another numeric expression.
    /// ```sql
    /// SELECT MOD(column_1, 2) FROM df;
    /// ```
    Mod,
    /// SQL 'sqrt' function.
    /// Returns the square root (√) of a number.
    /// ```sql
    /// SELECT SQRT(column_1) FROM df;
    /// ```
    Sqrt,
    /// SQL 'cbrt' function.
    /// Returns the cube root (∛) of a number.
    /// ```sql
    /// SELECT CBRT(column_1) FROM df;
    /// ```
    Cbrt,
    /// SQL 'round' function.
    /// Rounds a number to `x` decimals (default: 0); `x` must be a non-negative
    /// integer literal. Ties are resolved according to the default `RoundMode`.
    /// ```sql
    /// SELECT ROUND(column_1, 3) FROM df;
    /// ```
    Round,
    /// SQL 'sign' function.
    /// Returns the sign of the argument as -1, 0, or +1.
    /// ```sql
    /// SELECT SIGN(column_1) FROM df;
    /// ```
    Sign,

    // ----
    // Trig functions
    // ----
    /// SQL 'cos' function.
    /// Compute the cosine of the input expression (in radians).
    /// ```sql
    /// SELECT COS(column_1) FROM df;
    /// ```
    Cos,
    /// SQL 'cot' function.
    /// Compute the cotangent of the input expression (in radians).
    /// ```sql
    /// SELECT COT(column_1) FROM df;
    /// ```
    Cot,
    /// SQL 'sin' function.
    /// Compute the sine of the input expression (in radians).
    /// ```sql
    /// SELECT SIN(column_1) FROM df;
    /// ```
    Sin,
    /// SQL 'tan' function.
    /// Compute the tangent of the input expression (in radians).
    /// ```sql
    /// SELECT TAN(column_1) FROM df;
    /// ```
    Tan,
    /// SQL 'cosd' function.
    /// Compute the cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT COSD(column_1) FROM df;
    /// ```
    CosD,
    /// SQL 'cotd' function.
    /// Compute the cotangent of the input expression (in degrees).
    /// ```sql
    /// SELECT COTD(column_1) FROM df;
    /// ```
    CotD,
    /// SQL 'sind' function.
    /// Compute the sine of the input expression (in degrees).
    /// ```sql
    /// SELECT SIND(column_1) FROM df;
    /// ```
    SinD,
    /// SQL 'tand' function.
    /// Compute the tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT TAND(column_1) FROM df;
    /// ```
    TanD,
    /// SQL 'acos' function.
    /// Compute the inverse cosine of the input expression (in radians).
    /// ```sql
    /// SELECT ACOS(column_1) FROM df;
    /// ```
    Acos,
    /// SQL 'asin' function.
    /// Compute the inverse sine of the input expression (in radians).
    /// ```sql
    /// SELECT ASIN(column_1) FROM df;
    /// ```
    Asin,
    /// SQL 'atan' function.
    /// Compute the inverse tangent of the input expression (in radians).
    /// ```sql
    /// SELECT ATAN(column_1) FROM df;
    /// ```
    Atan,
    /// SQL 'atan2' function.
    /// Compute the inverse tangent of column_1/column_2 (in radians).
    /// ```sql
    /// SELECT ATAN2(column_1, column_2) FROM df;
    /// ```
    Atan2,
    /// SQL 'acosd' function.
    /// Compute the inverse cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT ACOSD(column_1) FROM df;
    /// ```
    AcosD,
    /// SQL 'asind' function.
    /// Compute the inverse sine of the input expression (in degrees).
    /// ```sql
    /// SELECT ASIND(column_1) FROM df;
    /// ```
    AsinD,
    /// SQL 'atand' function.
    /// Compute the inverse tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT ATAND(column_1) FROM df;
    /// ```
    AtanD,
    /// SQL 'atan2d' function.
    /// Compute the inverse tangent of column_1/column_2 (in degrees).
    /// ```sql
    /// SELECT ATAN2D(column_1, column_2) FROM df;
    /// ```
    Atan2D,
    /// SQL 'degrees' function.
    /// Convert radians to degrees.
    /// ```sql
    /// SELECT DEGREES(column_1) FROM df;
    /// ```
    Degrees,
    /// SQL 'radians' function.
    /// Convert degrees to radians.
    /// ```sql
    /// SELECT RADIANS(column_1) FROM df;
    /// ```
    Radians,

    // ----
    // Temporal functions
    // ----
    /// SQL 'date_part' function.
    /// Extracts a part of a date (or datetime) such as 'year', 'month', etc.
    /// ```sql
    /// SELECT DATE_PART('year', column_1) FROM df;
    /// SELECT DATE_PART('day', column_1) FROM df;
    /// ```
    DatePart,
    /// SQL 'strftime' function.
    /// Converts a datetime to a string using a format string.
    /// ```sql
    /// SELECT STRFTIME(column_1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strftime,

    // ----
    // String functions
    // ----
    /// SQL 'bit_length' function (bits).
    /// Returns the length of the input string in bits.
    /// ```sql
    /// SELECT BIT_LENGTH(column_1) FROM df;
    /// ```
    BitLength,
    /// SQL 'concat' function.
    /// Returns all input expressions concatenated together as a string.
    /// ```sql
    /// SELECT CONCAT(column_1, column_2) FROM df;
    /// ```
    Concat,
    /// SQL 'concat_ws' function.
    /// Returns all input expressions concatenated together
    /// (and interleaved with a separator) as a string.
    /// ```sql
    /// SELECT CONCAT_WS(':', column_1, column_2, column_3) FROM df;
    /// ```
    ConcatWS,
    /// SQL 'date' function.
    /// Converts a formatted string date to an actual Date type; ISO-8601 format is assumed
    /// unless a strftime-compatible formatting string is provided as the second parameter.
    /// ```sql
    /// SELECT DATE('2021-03-15') FROM df;
    /// SELECT DATE('2021-15-03', '%Y-%d-%m') FROM df;
    /// SELECT DATE('2021-03', '%Y-%m') FROM df;
    /// ```
    Date,
    /// SQL 'ends_with' function.
    /// Returns True if the value ends with the second argument.
    /// ```sql
    /// SELECT ENDS_WITH(column_1, 'a') FROM df;
    /// SELECT column_2 from df WHERE ENDS_WITH(column_1, 'a');
    /// ```
    EndsWith,
    /// SQL 'initcap' function.
    /// Returns the value with the first letter capitalized.
    /// ```sql
    /// SELECT INITCAP(column_1) FROM df;
    /// ```
    #[cfg(feature = "nightly")]
    InitCap,
    /// SQL 'left' function.
    /// Returns the first (leftmost) `n` characters.
    /// ```sql
    /// SELECT LEFT(column_1, 3) FROM df;
    /// ```
    Left,
    /// SQL 'length' function (characters).
    /// Returns the character length of the string.
    /// ```sql
    /// SELECT LENGTH(column_1) FROM df;
    /// ```
    Length,
    /// SQL 'lower' function.
    /// Returns a lowercased column.
    /// ```sql
    /// SELECT LOWER(column_1) FROM df;
    /// ```
    Lower,
    /// SQL 'ltrim' function.
    /// Strip whitespace from the left.
    /// ```sql
    /// SELECT LTRIM(column_1) FROM df;
    /// ```
    LTrim,
    /// SQL 'normalize' function.
    /// Convert string to Unicode normalization form
    /// (one of NFC, NFKC, NFD, or NFKD - unquoted).
    /// ```sql
    /// SELECT NORMALIZE(column_1, NFC) FROM df;
    /// ```
    Normalize,
    /// SQL 'octet_length' function.
    /// Returns the length of a given string in bytes.
    /// ```sql
    /// SELECT OCTET_LENGTH(column_1) FROM df;
    /// ```
    OctetLength,
    /// SQL 'regexp_like' function.
    /// True if `pattern` matches the value (optional: `flags`).
    /// ```sql
    /// SELECT REGEXP_LIKE(column_1, 'xyz', 'i') FROM df;
    /// ```
    RegexpLike,
    /// SQL 'replace' function.
    /// Replace a given substring with another string.
    /// ```sql
    /// SELECT REPLACE(column_1, 'old', 'new') FROM df;
    /// ```
    Replace,
    /// SQL 'reverse' function.
    /// Return the reversed string.
    /// ```sql
    /// SELECT REVERSE(column_1) FROM df;
    /// ```
    Reverse,
    /// SQL 'right' function.
    /// Returns the last (rightmost) `n` characters.
    /// ```sql
    /// SELECT RIGHT(column_1, 3) FROM df;
    /// ```
    Right,
    /// SQL 'rtrim' function.
    /// Strip whitespace from the right.
    /// ```sql
    /// SELECT RTRIM(column_1) FROM df;
    /// ```
    RTrim,
    /// SQL 'split_part' function.
    /// Splits a string into an array of strings using the given delimiter
    /// and returns the `n`-th part (1-indexed).
    /// ```sql
    /// SELECT SPLIT_PART(column_1, ',', 2) FROM df;
    /// ```
    SplitPart,
    /// SQL 'starts_with' function.
    /// Returns True if the value starts with the second argument.
    /// ```sql
    /// SELECT STARTS_WITH(column_1, 'a') FROM df;
    /// SELECT column_2 from df WHERE STARTS_WITH(column_1, 'a');
    /// ```
    StartsWith,
    /// SQL 'strpos' function.
    /// Returns the index of the given substring in the target string.
    /// ```sql
    /// SELECT STRPOS(column_1, 'xyz') FROM df;
    /// ```
    StrPos,
    /// SQL 'substr' function.
    /// Returns a portion of the data (first character = 1) in the range
    /// \[start, start + length].
    /// ```sql
    /// SELECT SUBSTR(column_1, 3, 5) FROM df;
    /// ```
    Substring,
    /// SQL 'string_to_array' function.
    /// Splits a string into an array of strings using the given delimiter.
    /// ```sql
    /// SELECT STRING_TO_ARRAY(column_1, ',') FROM df;
    /// ```
    StringToArray,
    /// SQL 'strptime' function.
    /// Converts a string to a datetime using a format string.
    /// ```sql
    /// SELECT STRPTIME(column_1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strptime,
    /// SQL 'time' function.
    /// Converts a formatted string time to an actual Time type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIME('10:30:45') FROM df;
    /// SELECT TIME('20.30', '%H.%M') FROM df;
    /// ```
    Time,
    /// SQL 'timestamp' function (alias: 'datetime').
    /// Converts a formatted string datetime to an actual Datetime type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIMESTAMP('2021-03-15 10:30:45') FROM df;
    /// SELECT TIMESTAMP('2021-15-03 00:01:02', '%Y-%d-%m %H:%M:%S') FROM df;
    /// ```
    Timestamp,
    /// SQL 'upper' function.
    /// Returns an uppercased column.
    /// ```sql
    /// SELECT UPPER(column_1) FROM df;
    /// ```
    Upper,

    // ----
    // Conditional functions
    // ----
    /// SQL 'coalesce' function.
    /// Returns the first non-null value in the provided values/columns.
    /// ```sql
    /// SELECT COALESCE(column_1, ...) FROM df;
    /// ```
    Coalesce,
    /// SQL 'greatest' function.
    /// Returns the greatest value in the list of expressions.
    /// ```sql
    /// SELECT GREATEST(column_1, column_2, ...) FROM df;
    /// ```
    Greatest,
    /// SQL 'if' function.
    /// Returns expr1 if the boolean condition provided as the first
    /// parameter evaluates to true, and expr2 otherwise.
    /// ```sql
    /// SELECT IF(column < 0, expr1, expr2) FROM df;
    /// ```
    If,
    /// SQL 'ifnull' function.
    /// If an expression value is NULL, return an alternative value.
    /// ```sql
    /// SELECT IFNULL(string_col, 'n/a') FROM df;
    /// ```
    IfNull,
    /// SQL 'least' function.
    /// Returns the smallest value in the list of expressions.
    /// ```sql
    /// SELECT LEAST(column_1, column_2, ...) FROM df;
    /// ```
    Least,
    /// SQL 'nullif' function.
    /// Returns NULL if two expressions are equal, otherwise returns the first.
    /// ```sql
    /// SELECT NULLIF(column_1, column_2) FROM df;
    /// ```
    NullIf,

    // ----
    // Aggregate functions
    // ----
    /// SQL 'avg' function.
    /// Returns the average (mean) of all the elements in the grouping.
    /// ```sql
    /// SELECT AVG(column_1) FROM df;
    /// ```
    Avg,
    /// SQL 'corr' function.
    /// Returns the Pearson correlation coefficient between two columns.
    /// ```sql
    /// SELECT CORR(column_1, column_2) FROM df;
    /// ```
    Corr,
    /// SQL 'count' function.
    /// Returns the number of elements in the grouping.
    /// ```sql
    /// SELECT COUNT(column_1) FROM df;
    /// SELECT COUNT(*) FROM df;
    /// SELECT COUNT(DISTINCT column_1) FROM df;
    /// SELECT COUNT(DISTINCT *) FROM df;
    /// ```
    Count,
    /// SQL 'covar_pop' function.
    /// Returns the population covariance between two columns.
    /// ```sql
    /// SELECT COVAR_POP(column_1, column_2) FROM df;
    /// ```
    CovarPop,
    /// SQL 'covar_samp' function (alias: 'covar').
    /// Returns the sample covariance between two columns.
    /// ```sql
    /// SELECT COVAR_SAMP(column_1, column_2) FROM df;
    /// ```
    CovarSamp,
    /// SQL 'first' function.
    /// Returns the first element of the grouping.
    /// ```sql
    /// SELECT FIRST(column_1) FROM df;
    /// ```
    First,
    /// SQL 'last' function.
    /// Returns the last element of the grouping.
    /// ```sql
    /// SELECT LAST(column_1) FROM df;
    /// ```
    Last,
    /// SQL 'max' function.
    /// Returns the greatest (maximum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MAX(column_1) FROM df;
    /// ```
    Max,
    /// SQL 'median' function.
    /// Returns the median element from the grouping.
    /// ```sql
    /// SELECT MEDIAN(column_1) FROM df;
    /// ```
    Median,
    /// SQL 'quantile_cont' function.
    /// Returns the continuous quantile element from the grouping
    /// (interpolated value between two closest values).
    /// ```sql
    /// SELECT QUANTILE_CONT(column_1) FROM df;
    /// ```
    QuantileCont,
    /// SQL 'quantile_disc' function.
    /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value,
    /// and returns the value associated with the subinterval where the quantile value falls.
    /// ```sql
    /// SELECT QUANTILE_DISC(column_1) FROM df;
    /// ```
    QuantileDisc,
    /// SQL 'min' function.
    /// Returns the smallest (minimum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MIN(column_1) FROM df;
    /// ```
    Min,
    /// SQL 'stddev' function.
    /// Returns the standard deviation of all the elements in the grouping.
    /// ```sql
    /// SELECT STDDEV(column_1) FROM df;
    /// ```
    StdDev,
    /// SQL 'sum' function.
    /// Returns the sum of all the elements in the grouping.
    /// ```sql
    /// SELECT SUM(column_1) FROM df;
    /// ```
    Sum,
    /// SQL 'variance' function.
    /// Returns the variance of all the elements in the grouping.
    /// ```sql
    /// SELECT VARIANCE(column_1) FROM df;
    /// ```
    Variance,

    // ----
    // Array functions
    // ----
    /// SQL 'array_length' function.
    /// Returns the length of the array.
    /// ```sql
    /// SELECT ARRAY_LENGTH(column_1) FROM df;
    /// ```
    ArrayLength,
    /// SQL 'array_lower' function.
    /// Returns the minimum value in an array; equivalent to `array_min`.
    /// ```sql
    /// SELECT ARRAY_LOWER(column_1) FROM df;
    /// ```
    ArrayMin,
    /// SQL 'array_upper' function.
    /// Returns the maximum value in an array; equivalent to `array_max`.
    /// ```sql
    /// SELECT ARRAY_UPPER(column_1) FROM df;
    /// ```
    ArrayMax,
    /// SQL 'array_sum' function.
    /// Returns the sum of all values in an array.
    /// ```sql
    /// SELECT ARRAY_SUM(column_1) FROM df;
    /// ```
    ArraySum,
    /// SQL 'array_mean' function.
    /// Returns the mean of all values in an array.
    /// ```sql
    /// SELECT ARRAY_MEAN(column_1) FROM df;
    /// ```
    ArrayMean,
    /// SQL 'array_reverse' function.
    /// Returns the array with the elements in reverse order.
    /// ```sql
    /// SELECT ARRAY_REVERSE(column_1) FROM df;
    /// ```
    ArrayReverse,
    /// SQL 'array_unique' function.
    /// Returns the array with the unique elements.
    /// ```sql
    /// SELECT ARRAY_UNIQUE(column_1) FROM df;
    /// ```
    ArrayUnique,
    /// SQL 'unnest' function.
    /// Unnests (explodes) an array column into multiple rows.
    /// ```sql
    /// SELECT unnest(column_1) FROM df;
    /// ```
    Explode,
    /// SQL 'array_agg' function.
    /// Concatenates the input expressions, including nulls, into an array.
    /// ```sql
    /// SELECT ARRAY_AGG(column_1, column_2, ...) FROM df;
    /// ```
    ArrayAgg,
    /// SQL 'array_to_string' function.
    /// Takes all elements of the array and joins them into one string.
    /// ```sql
    /// SELECT ARRAY_TO_STRING(column_1, ',') FROM df;
    /// SELECT ARRAY_TO_STRING(column_1, ',', 'n/a') FROM df;
    /// ```
    ArrayToString,
    /// SQL 'array_get' function.
    /// Returns the value at the given index in the array.
    /// ```sql
    /// SELECT ARRAY_GET(column_1, 1) FROM df;
    /// ```
    ArrayGet,
    /// SQL 'array_contains' function.
    /// Returns true if the array contains the value.
    /// ```sql
    /// SELECT ARRAY_CONTAINS(column_1, 'foo') FROM df;
    /// ```
    ArrayContains,

    // ----
    // Column selection
    // ----
    /// SQL 'columns' function (column selection).
    Columns,

    // ----
    // User-defined
    // ----
    /// A function found in the `SQLContext` function registry; holds the
    /// lowercased function name it was resolved from.
    Udf(String),
}
701
702
impl PolarsSQLFunctions {
703
pub(crate) fn keywords() -> &'static [&'static str] {
704
&[
705
"abs",
706
"acos",
707
"acosd",
708
"array_contains",
709
"array_get",
710
"array_length",
711
"array_lower",
712
"array_mean",
713
"array_reverse",
714
"array_sum",
715
"array_to_string",
716
"array_unique",
717
"array_upper",
718
"asin",
719
"asind",
720
"atan",
721
"atan2",
722
"atan2d",
723
"atand",
724
"avg",
725
"bit_and",
726
"bit_count",
727
"bit_length",
728
"bit_or",
729
"bit_xor",
730
"cbrt",
731
"ceil",
732
"ceiling",
733
"char_length",
734
"character_length",
735
"coalesce",
736
"columns",
737
"concat",
738
"concat_ws",
739
"corr",
740
"cos",
741
"cosd",
742
"cot",
743
"cotd",
744
"count",
745
"covar",
746
"covar_pop",
747
"covar_samp",
748
"date",
749
"date_part",
750
"degrees",
751
"ends_with",
752
"exp",
753
"first",
754
"floor",
755
"greatest",
756
"if",
757
"ifnull",
758
"initcap",
759
"last",
760
"least",
761
"left",
762
"length",
763
"ln",
764
"log",
765
"log10",
766
"log1p",
767
"log2",
768
"lower",
769
"ltrim",
770
"max",
771
"median",
772
"quantile_disc",
773
"min",
774
"mod",
775
"nullif",
776
"octet_length",
777
"pi",
778
"pow",
779
"power",
780
"quantile_cont",
781
"quantile_disc",
782
"radians",
783
"regexp_like",
784
"replace",
785
"reverse",
786
"right",
787
"round",
788
"rtrim",
789
"sign",
790
"sin",
791
"sind",
792
"sqrt",
793
"starts_with",
794
"stddev",
795
"stddev_samp",
796
"stdev",
797
"stdev_samp",
798
"strftime",
799
"strpos",
800
"strptime",
801
"substr",
802
"sum",
803
"tan",
804
"tand",
805
"unnest",
806
"upper",
807
"var",
808
"var_samp",
809
"variance",
810
]
811
}
812
}
813
814
impl PolarsSQLFunctions {
815
fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
816
let function_name = function.name.0[0].value.to_lowercase();
817
Ok(match function_name.as_str() {
818
// ----
819
// Bitwise functions
820
// ----
821
"bit_and" | "bitand" => Self::BitAnd,
822
#[cfg(feature = "bitwise")]
823
"bit_count" | "bitcount" => Self::BitCount,
824
"bit_or" | "bitor" => Self::BitOr,
825
"bit_xor" | "bitxor" | "xor" => Self::BitXor,
826
827
// ----
828
// Math functions
829
// ----
830
"abs" => Self::Abs,
831
"cbrt" => Self::Cbrt,
832
"ceil" | "ceiling" => Self::Ceil,
833
"div" => Self::Div,
834
"exp" => Self::Exp,
835
"floor" => Self::Floor,
836
"ln" => Self::Ln,
837
"log" => Self::Log,
838
"log10" => Self::Log10,
839
"log1p" => Self::Log1p,
840
"log2" => Self::Log2,
841
"mod" => Self::Mod,
842
"pi" => Self::Pi,
843
"pow" | "power" => Self::Pow,
844
"round" => Self::Round,
845
"sign" => Self::Sign,
846
"sqrt" => Self::Sqrt,
847
848
// ----
849
// Trig functions
850
// ----
851
"cos" => Self::Cos,
852
"cot" => Self::Cot,
853
"sin" => Self::Sin,
854
"tan" => Self::Tan,
855
"cosd" => Self::CosD,
856
"cotd" => Self::CotD,
857
"sind" => Self::SinD,
858
"tand" => Self::TanD,
859
"acos" => Self::Acos,
860
"asin" => Self::Asin,
861
"atan" => Self::Atan,
862
"atan2" => Self::Atan2,
863
"acosd" => Self::AcosD,
864
"asind" => Self::AsinD,
865
"atand" => Self::AtanD,
866
"atan2d" => Self::Atan2D,
867
"degrees" => Self::Degrees,
868
"radians" => Self::Radians,
869
870
// ----
871
// Conditional functions
872
// ----
873
"coalesce" => Self::Coalesce,
874
"greatest" => Self::Greatest,
875
"if" => Self::If,
876
"ifnull" => Self::IfNull,
877
"least" => Self::Least,
878
"nullif" => Self::NullIf,
879
880
// ----
881
// Date functions
882
// ----
883
"date_part" => Self::DatePart,
884
"strftime" => Self::Strftime,
885
886
// ----
887
// String functions
888
// ----
889
"bit_length" => Self::BitLength,
890
"concat" => Self::Concat,
891
"concat_ws" => Self::ConcatWS,
892
"date" => Self::Date,
893
"timestamp" | "datetime" => Self::Timestamp,
894
"ends_with" => Self::EndsWith,
895
#[cfg(feature = "nightly")]
896
"initcap" => Self::InitCap,
897
"length" | "char_length" | "character_length" => Self::Length,
898
"left" => Self::Left,
899
"lower" => Self::Lower,
900
"ltrim" => Self::LTrim,
901
"normalize" => Self::Normalize,
902
"octet_length" => Self::OctetLength,
903
"strpos" => Self::StrPos,
904
"regexp_like" => Self::RegexpLike,
905
"replace" => Self::Replace,
906
"reverse" => Self::Reverse,
907
"right" => Self::Right,
908
"rtrim" => Self::RTrim,
909
"split_part" => Self::SplitPart,
910
"starts_with" => Self::StartsWith,
911
"string_to_array" => Self::StringToArray,
912
"strptime" => Self::Strptime,
913
"substr" => Self::Substring,
914
"time" => Self::Time,
915
"upper" => Self::Upper,
916
917
// ----
918
// Aggregate functions
919
// ----
920
"avg" => Self::Avg,
921
"corr" => Self::Corr,
922
"count" => Self::Count,
923
"covar_pop" => Self::CovarPop,
924
"covar" | "covar_samp" => Self::CovarSamp,
925
"first" => Self::First,
926
"last" => Self::Last,
927
"max" => Self::Max,
928
"median" => Self::Median,
929
"quantile_cont" => Self::QuantileCont,
930
"quantile_disc" => Self::QuantileDisc,
931
"min" => Self::Min,
932
"stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
933
"sum" => Self::Sum,
934
"var" | "variance" | "var_samp" => Self::Variance,
935
936
// ----
937
// Array functions
938
// ----
939
"array_agg" => Self::ArrayAgg,
940
"array_contains" => Self::ArrayContains,
941
"array_get" => Self::ArrayGet,
942
"array_length" => Self::ArrayLength,
943
"array_lower" => Self::ArrayMin,
944
"array_mean" => Self::ArrayMean,
945
"array_reverse" => Self::ArrayReverse,
946
"array_sum" => Self::ArraySum,
947
"array_to_string" => Self::ArrayToString,
948
"array_unique" => Self::ArrayUnique,
949
"array_upper" => Self::ArrayMax,
950
"unnest" => Self::Explode,
951
952
// ----
953
// Column selection
954
// ----
955
"columns" => Self::Columns,
956
957
other => {
958
if ctx.function_registry.contains(other) {
959
Self::Udf(other.to_string())
960
} else {
961
polars_bail!(SQLInterface: "unsupported function '{}'", other);
962
}
963
},
964
})
965
}
966
}
967
968
impl SQLFunctionVisitor<'_> {
969
pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
970
use PolarsSQLFunctions::*;
971
use polars_lazy::prelude::Literal;
972
973
let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
974
let function = self.func;
975
976
// TODO: implement the following functions where possible
977
if !function.within_group.is_empty() {
978
polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
979
}
980
if function.filter.is_some() {
981
polars_bail!(SQLInterface: "'FILTER' is not currently supported")
982
}
983
if function.null_treatment.is_some() {
984
polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
985
}
986
987
let log_with_base =
988
|e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
989
match function_name {
990
// ----
991
// Bitwise functions
992
// ----
993
BitAnd => self.visit_binary::<Expr>(Expr::and),
994
#[cfg(feature = "bitwise")]
995
BitCount => self.visit_unary(Expr::bitwise_count_ones),
996
BitOr => self.visit_binary::<Expr>(Expr::or),
997
BitXor => self.visit_binary::<Expr>(Expr::xor),
998
999
// ----
1000
// Math functions
1001
// ----
1002
Abs => self.visit_unary(Expr::abs),
1003
Cbrt => self.visit_unary(Expr::cbrt),
1004
Ceil => self.visit_unary(Expr::ceil),
1005
Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1006
Exp => self.visit_unary(Expr::exp),
1007
Floor => self.visit_unary(Expr::floor),
1008
Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1009
Log => self.visit_binary(Expr::log),
1010
Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1011
Log1p => self.visit_unary(Expr::log1p),
1012
Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1013
Pi => self.visit_nullary(Expr::pi),
1014
Mod => self.visit_binary(|e1, e2| e1 % e2),
1015
Pow => self.visit_binary::<Expr>(Expr::pow),
1016
Round => {
1017
let args = extract_args(function)?;
1018
match args.len() {
1019
1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1020
2 => self.try_visit_binary(|e, decimals| {
1021
Ok(e.round(match decimals {
1022
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1023
if n >= 0 { n as u32 } else {
1024
polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1])
1025
}
1026
},
1027
_ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1028
}, RoundMode::default()))
1029
}),
1030
_ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1031
}
1032
},
1033
Sign => self.visit_unary(Expr::sign),
1034
Sqrt => self.visit_unary(Expr::sqrt),
1035
1036
// ----
1037
// Trig functions
1038
// ----
1039
Acos => self.visit_unary(Expr::arccos),
1040
AcosD => self.visit_unary(|e| e.arccos().degrees()),
1041
Asin => self.visit_unary(Expr::arcsin),
1042
AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1043
Atan => self.visit_unary(Expr::arctan),
1044
Atan2 => self.visit_binary(Expr::arctan2),
1045
Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1046
AtanD => self.visit_unary(|e| e.arctan().degrees()),
1047
Cos => self.visit_unary(Expr::cos),
1048
CosD => self.visit_unary(|e| e.radians().cos()),
1049
Cot => self.visit_unary(Expr::cot),
1050
CotD => self.visit_unary(|e| e.radians().cot()),
1051
Degrees => self.visit_unary(Expr::degrees),
1052
Radians => self.visit_unary(Expr::radians),
1053
Sin => self.visit_unary(Expr::sin),
1054
SinD => self.visit_unary(|e| e.radians().sin()),
1055
Tan => self.visit_unary(Expr::tan),
1056
TanD => self.visit_unary(|e| e.radians().tan()),
1057
1058
// ----
1059
// Conditional functions
1060
// ----
1061
Coalesce => self.visit_variadic(coalesce),
1062
Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1063
If => {
1064
let args = extract_args(function)?;
1065
match args.len() {
1066
3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1067
Ok(when(cond).then(expr1).otherwise(expr2))
1068
}),
1069
_ => {
1070
polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1071
)
1072
},
1073
}
1074
},
1075
IfNull => {
1076
let args = extract_args(function)?;
1077
match args.len() {
1078
2 => self.visit_variadic(coalesce),
1079
_ => {
1080
polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1081
},
1082
}
1083
},
1084
Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1085
NullIf => {
1086
let args = extract_args(function)?;
1087
match args.len() {
1088
2 => self.visit_binary(|l: Expr, r: Expr| {
1089
when(l.clone().eq(r))
1090
.then(lit(LiteralValue::untyped_null()))
1091
.otherwise(l)
1092
}),
1093
_ => {
1094
polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1095
},
1096
}
1097
},
1098
1099
// ----
1100
// Date functions
1101
// ----
1102
DatePart => self.try_visit_binary(|part, e| {
1103
match part {
1104
Expr::Literal(p) if p.extract_str().is_some() => {
1105
let p = p.extract_str().unwrap();
1106
// note: 'DATE_PART' and 'EXTRACT' are minor syntactic
1107
// variations on otherwise identical functionality
1108
parse_extract_date_part(
1109
e,
1110
&DateTimeField::Custom(Ident {
1111
value: p.to_string(),
1112
quote_style: None,
1113
span: Span::empty(),
1114
}),
1115
)
1116
},
1117
_ => {
1118
polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1119
},
1120
}
1121
}),
1122
Strftime => {
1123
let args = extract_args(function)?;
1124
match args.len() {
1125
2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1126
_ => {
1127
polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1128
},
1129
}
1130
},
1131
1132
// ----
1133
// String functions
1134
// ----
1135
BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1136
Concat => {
1137
let args = extract_args(function)?;
1138
if args.is_empty() {
1139
polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1140
} else {
1141
self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1142
}
1143
},
1144
ConcatWS => {
1145
let args = extract_args(function)?;
1146
if args.len() < 2 {
1147
polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1148
} else {
1149
self.try_visit_variadic(|exprs: &[Expr]| {
1150
match &exprs[0] {
1151
Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1152
_ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1153
}
1154
})
1155
}
1156
},
1157
Date => {
1158
let args = extract_args(function)?;
1159
match args.len() {
1160
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1161
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1162
_ => {
1163
polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1164
},
1165
}
1166
},
1167
EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1168
#[cfg(feature = "nightly")]
1169
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1170
Left => self.try_visit_binary(|e, length| {
1171
Ok(match length {
1172
Expr::Literal(lv) if lv.is_null() => lit(lv),
1173
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1174
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1175
let len = if n > 0 {
1176
lit(n)
1177
} else {
1178
(e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1179
};
1180
e.str().slice(lit(0), len)
1181
},
1182
Expr::Literal(v) => {
1183
polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1184
},
1185
_ => when(length.clone().gt_eq(lit(0)))
1186
.then(e.clone().str().slice(lit(0), length.clone().abs()))
1187
.otherwise(e.clone().str().slice(
1188
lit(0),
1189
(e.str().len_chars() + length.clone()).clip_min(lit(0)),
1190
)),
1191
})
1192
}),
1193
Length => self.visit_unary(|e| e.str().len_chars()),
1194
Lower => self.visit_unary(|e| e.str().to_lowercase()),
1195
LTrim => {
1196
let args = extract_args(function)?;
1197
match args.len() {
1198
1 => self.visit_unary(|e| {
1199
e.str().strip_chars_start(lit(LiteralValue::untyped_null()))
1200
}),
1201
2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)),
1202
_ => {
1203
polars_bail!(SQLSyntax: "LTRIM expects 1-2 arguments (found {})", args.len())
1204
},
1205
}
1206
},
1207
Normalize => {
1208
let args = extract_args(function)?;
1209
match args.len() {
1210
1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1211
2 => {
1212
let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1213
value: s,
1214
quote_style: None,
1215
span: _,
1216
})) = args[1]
1217
{
1218
match s.to_uppercase().as_str() {
1219
"NFC" => UnicodeForm::NFC,
1220
"NFD" => UnicodeForm::NFD,
1221
"NFKC" => UnicodeForm::NFKC,
1222
"NFKD" => UnicodeForm::NFKD,
1223
_ => {
1224
polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1225
},
1226
}
1227
} else {
1228
polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1229
};
1230
self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1231
},
1232
_ => {
1233
polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1234
},
1235
}
1236
},
1237
OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1238
StrPos => {
1239
// note: SQL is 1-indexed; returns zero if no match found
1240
self.visit_binary(|expr, substring| {
1241
(expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1242
})
1243
},
1244
RegexpLike => {
1245
let args = extract_args(function)?;
1246
match args.len() {
1247
2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1248
3 => self.try_visit_ternary(|e, pat, flags| {
1249
Ok(e.str().contains(
1250
match (pat, flags) {
1251
(Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1252
let s = s_lv.extract_str().unwrap();
1253
let f = f_lv.extract_str().unwrap();
1254
if f.is_empty() {
1255
polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1256
};
1257
lit(format!("(?{f}){s}"))
1258
},
1259
_ => {
1260
polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1261
},
1262
},
1263
true))
1264
}),
1265
_ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1266
}
1267
},
1268
Replace => {
1269
let args = extract_args(function)?;
1270
match args.len() {
1271
3 => self
1272
.try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1273
_ => {
1274
polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1275
},
1276
}
1277
},
1278
Reverse => self.visit_unary(|e| e.str().reverse()),
1279
Right => self.try_visit_binary(|e, length| {
1280
Ok(match length {
1281
Expr::Literal(lv) if lv.is_null() => lit(lv),
1282
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1283
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1284
let n: i64 = n.try_into().unwrap();
1285
let offset = if n < 0 {
1286
lit(n.abs())
1287
} else {
1288
e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1289
};
1290
e.str().slice(offset, lit(LiteralValue::untyped_null()))
1291
},
1292
Expr::Literal(v) => {
1293
polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1294
},
1295
_ => when(length.clone().lt(lit(0)))
1296
.then(
1297
e.clone()
1298
.str()
1299
.slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1300
)
1301
.otherwise(e.clone().str().slice(
1302
e.str().len_chars().cast(DataType::Int32) - length.clone(),
1303
lit(LiteralValue::untyped_null()),
1304
)),
1305
})
1306
}),
1307
RTrim => {
1308
let args = extract_args(function)?;
1309
match args.len() {
1310
1 => self.visit_unary(|e| {
1311
e.str().strip_chars_end(lit(LiteralValue::untyped_null()))
1312
}),
1313
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
1314
_ => {
1315
polars_bail!(SQLSyntax: "RTRIM expects 1-2 arguments (found {})", args.len())
1316
},
1317
}
1318
},
1319
SplitPart => {
1320
let args = extract_args(function)?;
1321
match args.len() {
1322
3 => self.try_visit_ternary(|e, sep, idx| {
1323
let idx = adjust_one_indexed_param(idx, true);
1324
Ok(when(e.clone().is_not_null())
1325
.then(
1326
e.clone()
1327
.str()
1328
.split(sep)
1329
.list()
1330
.get(idx, true)
1331
.fill_null(lit("")),
1332
)
1333
.otherwise(e))
1334
}),
1335
_ => {
1336
polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1337
},
1338
}
1339
},
1340
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1341
StringToArray => {
1342
let args = extract_args(function)?;
1343
match args.len() {
1344
2 => self.visit_binary(|e, sep| e.str().split(sep)),
1345
_ => {
1346
polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1347
},
1348
}
1349
},
1350
Strptime => {
1351
let args = extract_args(function)?;
1352
match args.len() {
1353
2 => self.visit_binary(|e, fmt: String| {
1354
e.str().strptime(
1355
DataType::Datetime(TimeUnit::Microseconds, None),
1356
StrptimeOptions {
1357
format: Some(fmt.into()),
1358
..Default::default()
1359
},
1360
lit("latest"),
1361
)
1362
}),
1363
_ => {
1364
polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1365
},
1366
}
1367
},
1368
Time => {
1369
let args = extract_args(function)?;
1370
match args.len() {
1371
1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1372
2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1373
_ => {
1374
polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1375
},
1376
}
1377
},
1378
Timestamp => {
1379
let args = extract_args(function)?;
1380
match args.len() {
1381
1 => self.visit_unary(|e| {
1382
e.str()
1383
.to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1384
}),
1385
2 => self
1386
.visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1387
_ => {
1388
polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1389
},
1390
}
1391
},
1392
Substring => {
1393
let args = extract_args(function)?;
1394
match args.len() {
1395
// note: SQL is 1-indexed, hence the need for adjustments
1396
2 => self.try_visit_binary(|e, start| {
1397
Ok(match start {
1398
Expr::Literal(lv) if lv.is_null() => lit(lv),
1399
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1400
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1401
Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1402
_ => start.clone() + lit(1),
1403
})
1404
}),
1405
3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1406
Ok(match (start.clone(), length.clone()) {
1407
(Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1408
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1409
polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1410
},
1411
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1412
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1413
e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1414
},
1415
(Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1416
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1417
polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1418
},
1419
_ => {
1420
let adjusted_start = start - lit(1);
1421
when(adjusted_start.clone().lt(lit(0)))
1422
.then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1423
.otherwise(e.str().slice(adjusted_start, length))
1424
}
1425
})
1426
}),
1427
_ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1428
}
1429
},
1430
Upper => self.visit_unary(|e| e.str().to_uppercase()),
1431
1432
// ----
1433
// Aggregate functions
1434
// ----
1435
Avg => self.visit_unary(Expr::mean),
1436
Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1437
Count => self.visit_count(),
1438
CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1439
CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1440
First => self.visit_unary(Expr::first),
1441
Last => self.visit_unary(Expr::last),
1442
Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1443
Median => self.visit_unary(Expr::median),
1444
QuantileCont => {
1445
let args = extract_args(function)?;
1446
match args.len() {
1447
2 => self.try_visit_binary(|e, q| {
1448
let value = match q {
1449
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1450
if (0.0..=1.0).contains(&f) {
1451
Expr::from(f)
1452
} else {
1453
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1454
}
1455
},
1456
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1457
if (0..=1).contains(&n) {
1458
Expr::from(n as f64)
1459
} else {
1460
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1461
}
1462
},
1463
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
1464
};
1465
Ok(e.quantile(value, QuantileMethod::Linear))
1466
}),
1467
_ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
1468
}
1469
},
1470
QuantileDisc => {
1471
let args = extract_args(function)?;
1472
match args.len() {
1473
2 => self.try_visit_binary(|e, q| {
1474
let value = match q {
1475
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1476
if (0.0..=1.0).contains(&f) {
1477
Expr::from(f)
1478
} else {
1479
polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1480
}
1481
},
1482
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1483
if (0..=1).contains(&n) {
1484
Expr::from(n as f64)
1485
} else {
1486
polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1487
}
1488
},
1489
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1])
1490
};
1491
Ok(e.quantile(value, QuantileMethod::Equiprobable))
1492
}),
1493
_ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()),
1494
}
1495
},
1496
Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1497
StdDev => self.visit_unary(|e| e.std(1)),
1498
Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1499
Variance => self.visit_unary(|e| e.var(1)),
1500
1501
// ----
1502
// Array functions
1503
// ----
1504
ArrayAgg => self.visit_arr_agg(),
1505
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1506
ArrayGet => {
1507
// note: SQL is 1-indexed, not 0-indexed
1508
self.visit_binary(|e, idx: Expr| {
1509
let idx = adjust_one_indexed_param(idx, true);
1510
e.list().get(idx, true)
1511
})
1512
},
1513
ArrayLength => self.visit_unary(|e| e.list().len()),
1514
ArrayMax => self.visit_unary(|e| e.list().max()),
1515
ArrayMean => self.visit_unary(|e| e.list().mean()),
1516
ArrayMin => self.visit_unary(|e| e.list().min()),
1517
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
1518
ArraySum => self.visit_unary(|e| e.list().sum()),
1519
ArrayToString => self.visit_arr_to_string(),
1520
ArrayUnique => self.visit_unary(|e| e.list().unique()),
1521
Explode => self.visit_unary(|e| e.explode()),
1522
1523
// ----
1524
// Column selection
1525
// ----
1526
Columns => {
1527
let active_schema = self.active_schema;
1528
self.try_visit_unary(|e: Expr| match e {
1529
Expr::Literal(lv) if lv.extract_str().is_some() => {
1530
let pat = lv.extract_str().unwrap();
1531
if pat == "*" {
1532
polars_bail!(
1533
SQLSyntax: "COLUMNS('*') is not a valid regex; \
1534
did you mean COLUMNS(*)?"
1535
)
1536
};
1537
let pat = match pat {
1538
_ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1539
_ if pat.starts_with('^') => format!("{pat}.*$"),
1540
_ if pat.ends_with('$') => format!("^.*{pat}"),
1541
_ => format!("^.*{pat}.*$"),
1542
};
1543
if let Some(active_schema) = &active_schema {
1544
let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1545
let col_names = active_schema
1546
.iter_names()
1547
.filter(|name| rx.is_match(name))
1548
.cloned()
1549
.collect::<Vec<_>>();
1550
1551
Ok(if col_names.len() == 1 {
1552
col(col_names.into_iter().next().unwrap())
1553
} else {
1554
cols(col_names).as_expr()
1555
})
1556
} else {
1557
Ok(col(pat.as_str()))
1558
}
1559
},
1560
Expr::Selector(s) => Ok(s.as_expr()),
1561
_ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1562
})
1563
},
1564
1565
// ----
1566
// User-defined
1567
// ----
1568
Udf(func_name) => self.visit_udf(&func_name),
1569
}
1570
}
1571
1572
fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1573
let args = extract_args(self.func)?
1574
.into_iter()
1575
.map(|arg| {
1576
if let FunctionArgExpr::Expr(e) = arg {
1577
parse_sql_expr(e, self.ctx, self.active_schema)
1578
} else {
1579
polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1580
}
1581
})
1582
.collect::<PolarsResult<Vec<_>>>()?;
1583
1584
Ok(self
1585
.ctx
1586
.function_registry
1587
.get_udf(func_name)?
1588
.ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1589
.call(args))
1590
}
1591
1592
/// Window specs without partition bys are essentially cumulative functions
1593
/// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false)
1594
fn apply_cumulative_window(
1595
&mut self,
1596
f: impl Fn(Expr) -> Expr,
1597
cumulative_f: impl Fn(Expr, bool) -> Expr,
1598
WindowSpec {
1599
partition_by,
1600
order_by,
1601
..
1602
}: &WindowSpec,
1603
) -> PolarsResult<Expr> {
1604
if !order_by.is_empty() && partition_by.is_empty() {
1605
let (order_by, desc): (Vec<Expr>, Vec<bool>) = order_by
1606
.iter()
1607
.map(|o| {
1608
let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
1609
Ok(match o.asc {
1610
Some(b) => (expr, !b),
1611
None => (expr, false),
1612
})
1613
})
1614
.collect::<PolarsResult<Vec<_>>>()?
1615
.into_iter()
1616
.unzip();
1617
self.visit_unary_no_window(|e| {
1618
cumulative_f(
1619
e.sort_by(
1620
&order_by,
1621
SortMultipleOptions::default().with_order_descending_multi(desc.clone()),
1622
),
1623
false,
1624
)
1625
})
1626
} else {
1627
self.visit_unary(f)
1628
}
1629
}
1630
1631
/// Visit a single-argument function with an infallible transform `f`.
///
/// Thin adapter over `try_visit_unary`, which parses the argument and applies any
/// OVER window spec to the result.
fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
    let fallible = move |e: Expr| -> PolarsResult<Expr> { Ok(f(e)) };
    self.try_visit_unary(fallible)
}
1634
1635
fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1636
let args = extract_args(self.func)?;
1637
match args.as_slice() {
1638
[FunctionArgExpr::Expr(sql_expr)] => {
1639
f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?)
1640
},
1641
[FunctionArgExpr::Wildcard] => f(parse_sql_expr(
1642
&SQLExpr::Wildcard(AttachedToken::empty()),
1643
self.ctx,
1644
self.active_schema,
1645
)?),
1646
_ => self.not_supported_error(),
1647
}
1648
.and_then(|e| self.apply_window_spec(e, &self.func.over))
1649
}
1650
1651
/// Some functions have cumulative equivalents that can be applied to window specs
1652
/// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false)
1653
/// visit_unary_with_cumulative_window will take in a function & a cumulative function
1654
/// if there is a cumulative window spec, it will apply the cumulative function,
1655
/// otherwise it will apply the function
1656
fn visit_unary_with_opt_cumulative(
1657
&mut self,
1658
f: impl Fn(Expr) -> Expr,
1659
cumulative_f: impl Fn(Expr, bool) -> Expr,
1660
) -> PolarsResult<Expr> {
1661
match self.func.over.as_ref() {
1662
Some(WindowType::WindowSpec(spec)) => {
1663
self.apply_cumulative_window(f, cumulative_f, spec)
1664
},
1665
Some(WindowType::NamedWindow(named_window)) => polars_bail!(
1666
SQLInterface: "Named windows are not currently supported; found {:?}",
1667
named_window
1668
),
1669
_ => self.visit_unary(f),
1670
}
1671
}
1672
1673
fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1674
let args = extract_args(self.func)?;
1675
match args.as_slice() {
1676
[FunctionArgExpr::Expr(sql_expr)] => {
1677
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1678
// apply the function on the inner expr -- e.g. SUM(a) -> SUM
1679
Ok(f(expr))
1680
},
1681
_ => self.not_supported_error(),
1682
}
1683
}
1684
1685
fn visit_binary<Arg: FromSQLExpr>(
1686
&mut self,
1687
f: impl Fn(Expr, Arg) -> Expr,
1688
) -> PolarsResult<Expr> {
1689
self.try_visit_binary(|e, a| Ok(f(e, a)))
1690
}
1691
1692
fn try_visit_binary<Arg: FromSQLExpr>(
1693
&mut self,
1694
f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
1695
) -> PolarsResult<Expr> {
1696
let args = extract_args(self.func)?;
1697
match args.as_slice() {
1698
[
1699
FunctionArgExpr::Expr(sql_expr1),
1700
FunctionArgExpr::Expr(sql_expr2),
1701
] => {
1702
let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1703
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1704
f(expr1, expr2)
1705
},
1706
_ => self.not_supported_error(),
1707
}
1708
}
1709
1710
/// Visit a function taking any number of expression arguments, with an infallible `f`.
fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
    let fallible = move |exprs: &[Expr]| -> PolarsResult<Expr> { Ok(f(exprs)) };
    self.try_visit_variadic(fallible)
}
1713
1714
fn try_visit_variadic(
1715
&mut self,
1716
f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
1717
) -> PolarsResult<Expr> {
1718
let args = extract_args(self.func)?;
1719
let mut expr_args = vec![];
1720
for arg in args {
1721
if let FunctionArgExpr::Expr(sql_expr) = arg {
1722
expr_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
1723
} else {
1724
return self.not_supported_error();
1725
};
1726
}
1727
f(&expr_args)
1728
}
1729
1730
fn try_visit_ternary<Arg: FromSQLExpr>(
1731
&mut self,
1732
f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
1733
) -> PolarsResult<Expr> {
1734
let args = extract_args(self.func)?;
1735
match args.as_slice() {
1736
[
1737
FunctionArgExpr::Expr(sql_expr1),
1738
FunctionArgExpr::Expr(sql_expr2),
1739
FunctionArgExpr::Expr(sql_expr3),
1740
] => {
1741
let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1742
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1743
let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
1744
f(expr1, expr2, expr3)
1745
},
1746
_ => self.not_supported_error(),
1747
}
1748
}
1749
1750
fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
1751
let args = extract_args(self.func)?;
1752
if !args.is_empty() {
1753
return self.not_supported_error();
1754
}
1755
Ok(f())
1756
}
1757
1758
fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
1759
let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
1760
match args.as_slice() {
1761
[FunctionArgExpr::Expr(sql_expr)] => {
1762
let mut base = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1763
if is_distinct {
1764
base = base.unique_stable();
1765
}
1766
for clause in clauses {
1767
match clause {
1768
FunctionArgumentClause::OrderBy(order_exprs) => {
1769
base = self.apply_order_by(base, order_exprs.as_slice())?;
1770
},
1771
FunctionArgumentClause::Limit(limit_expr) => {
1772
let limit = parse_sql_expr(&limit_expr, self.ctx, self.active_schema)?;
1773
match limit {
1774
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))
1775
if n >= 0 =>
1776
{
1777
base = base.head(Some(n as usize))
1778
},
1779
_ => {
1780
polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer")
1781
},
1782
};
1783
},
1784
_ => {},
1785
}
1786
}
1787
Ok(base.implode())
1788
},
1789
_ => {
1790
polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
1791
},
1792
}
1793
}
1794
1795
fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
1796
let args = extract_args(self.func)?;
1797
match args.len() {
1798
2 => self.try_visit_binary(|e, sep| {
1799
Ok(e.cast(DataType::List(Box::from(DataType::String)))
1800
.list()
1801
.join(sep, true))
1802
}),
1803
#[cfg(feature = "list_eval")]
1804
3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
1805
Expr::Literal(lv) if lv.extract_str().is_some() => {
1806
Ok(if lv.extract_str().unwrap().is_empty() {
1807
e.cast(DataType::List(Box::from(DataType::String)))
1808
.list()
1809
.join(sep, true)
1810
} else {
1811
e.cast(DataType::List(Box::from(DataType::String)))
1812
.list()
1813
.eval(col("").fill_null(lit(lv.extract_str().unwrap())))
1814
.list()
1815
.join(sep, false)
1816
})
1817
},
1818
_ => {
1819
polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
1820
},
1821
}),
1822
_ => {
1823
polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
1824
},
1825
}
1826
}
1827
1828
fn visit_count(&mut self) -> PolarsResult<Expr> {
1829
let (args, is_distinct) = extract_args_distinct(self.func)?;
1830
let count_expr = match (is_distinct, args.as_slice()) {
1831
// count(*), count()
1832
(false, [FunctionArgExpr::Wildcard] | []) => len(),
1833
// count(column_name)
1834
(false, [FunctionArgExpr::Expr(sql_expr)]) => {
1835
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1836
expr.count()
1837
},
1838
// count(distinct column_name)
1839
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
1840
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1841
expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
1842
},
1843
_ => self.not_supported_error()?,
1844
};
1845
self.apply_window_spec(count_expr, &self.func.over)
1846
}
1847
1848
fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
1849
let mut by = Vec::with_capacity(order_by.len());
1850
let mut descending = Vec::with_capacity(order_by.len());
1851
let mut nulls_last = Vec::with_capacity(order_by.len());
1852
1853
for ob in order_by {
1854
// note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise
1855
// https://www.postgresql.org/docs/current/queries-order.html
1856
let desc_order = !ob.asc.unwrap_or(true);
1857
by.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?);
1858
nulls_last.push(!ob.nulls_first.unwrap_or(desc_order));
1859
descending.push(desc_order);
1860
}
1861
Ok(expr.sort_by(
1862
by,
1863
SortMultipleOptions::default()
1864
.with_order_descending_multi(descending)
1865
.with_nulls_last_multi(nulls_last)
1866
.with_maintain_order(true),
1867
))
1868
}
1869
1870
fn apply_window_spec(
1871
&mut self,
1872
expr: Expr,
1873
window_type: &Option<WindowType>,
1874
) -> PolarsResult<Expr> {
1875
Ok(match &window_type {
1876
Some(WindowType::WindowSpec(window_spec)) => {
1877
if window_spec.partition_by.is_empty() {
1878
let exprs = window_spec
1879
.order_by
1880
.iter()
1881
.map(|o| {
1882
let e = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
1883
Ok(o.asc.map_or(e.clone(), |b| {
1884
e.sort(SortOptions::default().with_order_descending(!b))
1885
}))
1886
})
1887
.collect::<PolarsResult<Vec<_>>>()?;
1888
expr.over(exprs)
1889
} else {
1890
// Process for simple window specification, partition by first
1891
let partition_by = window_spec
1892
.partition_by
1893
.iter()
1894
.map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1895
.collect::<PolarsResult<Vec<_>>>()?;
1896
expr.over(partition_by)
1897
}
1898
},
1899
Some(WindowType::NamedWindow(named_window)) => polars_bail!(
1900
SQLInterface: "Named windows are not currently supported; found {:?}",
1901
named_window
1902
),
1903
None => expr,
1904
})
1905
}
1906
1907
fn not_supported_error(&self) -> PolarsResult<Expr> {
1908
polars_bail!(
1909
SQLInterface:
1910
"no function matches the given name and arguments: `{}`",
1911
self.func.to_string()
1912
);
1913
}
1914
}
1915
1916
/// Return the argument expressions of `func`.
///
/// DISTINCT and any argument clauses (such as ORDER BY or LIMIT) are rejected as
/// unexpected syntax.
fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
    _extract_func_args(func, false, false).map(|(args, _, _)| args)
}
1920
1921
/// Return the argument expressions of `func` together with whether DISTINCT was used.
///
/// Argument clauses (such as ORDER BY or LIMIT) are still rejected.
fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
    _extract_func_args(func, true, false).map(|(args, is_distinct, _)| (args, is_distinct))
}
1925
1926
fn extract_args_and_clauses(
1927
func: &SQLFunction,
1928
) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
1929
_extract_func_args(func, true, true)
1930
}
1931
1932
fn _extract_func_args(
1933
func: &SQLFunction,
1934
get_distinct: bool,
1935
get_clauses: bool,
1936
) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
1937
match &func.args {
1938
FunctionArguments::List(FunctionArgumentList {
1939
args,
1940
duplicate_treatment,
1941
clauses,
1942
}) => {
1943
let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
1944
if !(get_clauses || get_distinct) && is_distinct {
1945
polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
1946
} else if !get_clauses && !clauses.is_empty() {
1947
polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
1948
} else {
1949
let unpacked_args = args
1950
.iter()
1951
.map(|arg| match arg {
1952
FunctionArg::Named { arg, .. } => arg,
1953
FunctionArg::ExprNamed { arg, .. } => arg,
1954
FunctionArg::Unnamed(arg) => arg,
1955
})
1956
.collect();
1957
Ok((unpacked_args, is_distinct, clauses.clone()))
1958
}
1959
},
1960
FunctionArguments::Subquery { .. } => {
1961
Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
1962
},
1963
FunctionArguments::None => Ok((vec![], false, vec![])),
1964
}
1965
}
1966
1967
pub(crate) trait FromSQLExpr {
1968
fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
1969
where
1970
Self: Sized;
1971
}
1972
1973
impl FromSQLExpr for f64 {
1974
fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
1975
where
1976
Self: Sized,
1977
{
1978
match expr {
1979
SQLExpr::Value(v) => match v {
1980
SQLValue::Number(s, _) => s
1981
.parse()
1982
.map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
1983
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
1984
},
1985
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
1986
}
1987
}
1988
}
1989
1990
impl FromSQLExpr for bool {
1991
fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
1992
where
1993
Self: Sized,
1994
{
1995
match expr {
1996
SQLExpr::Value(v) => match v {
1997
SQLValue::Boolean(v) => Ok(*v),
1998
_ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
1999
},
2000
_ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2001
}
2002
}
2003
}
2004
2005
impl FromSQLExpr for String {
2006
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2007
where
2008
Self: Sized,
2009
{
2010
match expr {
2011
SQLExpr::Value(v) => match v {
2012
SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2013
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2014
},
2015
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2016
}
2017
}
2018
}
2019
2020
impl FromSQLExpr for StrptimeOptions {
2021
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2022
where
2023
Self: Sized,
2024
{
2025
match expr {
2026
SQLExpr::Value(v) => match v {
2027
SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2028
format: Some(PlSmallStr::from_str(s)),
2029
..StrptimeOptions::default()
2030
}),
2031
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2032
},
2033
_ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2034
}
2035
}
2036
}
2037
2038
impl FromSQLExpr for Expr {
2039
fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2040
where
2041
Self: Sized,
2042
{
2043
parse_sql_expr(expr, ctx, None)
2044
}
2045
}
2046
2047