Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/tests/simple_exprs.rs
8384 views
1
use polars_core::prelude::*;
2
use polars_lazy::prelude::*;
3
use polars_sql::*;
4
use polars_time::Duration;
5
6
fn create_sample_df() -> DataFrame {
7
let a = Column::new(
8
"a".into(),
9
(1..10000i64).map(|i| i / 100).collect::<Vec<_>>(),
10
);
11
let b = Column::new("b".into(), 1..10000i64);
12
DataFrame::new_infer_height(vec![a, b]).unwrap()
13
}
14
15
fn create_struct_df() -> (DataFrame, DataFrame) {
16
let struct_cols = vec![col("num"), col("str"), col("val")];
17
let df = df! {
18
"num" => [100, 250, 300, 350],
19
"str" => ["b", "a", "b", "a"],
20
"val" => [4.0, 3.5, 2.0, 1.5],
21
}
22
.unwrap();
23
24
(
25
df.clone()
26
.lazy()
27
.select([as_struct(struct_cols).alias("json_msg")])
28
.collect()
29
.unwrap(),
30
df,
31
)
32
}
33
34
fn assert_sql_to_polars(df: &DataFrame, sql: &str, f: impl FnOnce(LazyFrame) -> LazyFrame) {
35
let mut context = SQLContext::new();
36
context.register("df", df.clone().lazy());
37
let df_sql = context.execute(sql).unwrap().collect().unwrap();
38
let df_pl = f(df.clone().lazy()).collect().unwrap();
39
assert!(df_sql.equals(&df_pl));
40
}
41
42
/// A projection with a computed column, a range filter and a `LIMIT` must
/// match the equivalent lazy-frame pipeline.
#[test]
fn test_simple_select() -> PolarsResult<()> {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let query = r#"
        SELECT a, b, a + b as c
        FROM df
        where a > 10 and a < 20
        LIMIT 100
    "#;
    let from_sql = ctx.execute(query)?.collect()?;

    let in_range = col("a").gt(lit(10)).and(col("a").lt(lit(20)));
    let from_native = sample
        .lazy()
        .filter(in_range)
        .select(&[col("a"), col("b"), (col("a") + col("b")).alias("c")])
        .limit(100)
        .collect()?;

    assert_eq!(from_sql, from_native);
    Ok(())
}
66
67
/// A parenthesised `WHERE` predicate must behave like the plain filter.
#[test]
fn test_nested_expr() -> PolarsResult<()> {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let query = r#"SELECT * FROM df WHERE (a > 3)"#;
    let from_sql = ctx.execute(query)?.collect()?;

    let above_three = col("a").gt(lit(3));
    let from_native = sample.lazy().filter(above_three).collect()?;

    assert_eq!(from_sql, from_native);
    Ok(())
}
79
80
/// `GROUP BY` with aliased key and several aggregates (including `COUNT(*)`)
/// must match `group_by(..).agg(..)`. Both results are sorted by the key
/// before comparison.
#[test]
fn test_group_by_simple() -> PolarsResult<()> {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let query = r#"
        SELECT
          a          AS "aa",
          SUM(b)     AS "bb",
          SUM(a + b) AS "cc",
          COUNT(a)   AS "count_a",
          COUNT(*)   AS "count_star"
        FROM df
        GROUP BY a
        LIMIT 100
    "#;
    let from_sql = ctx
        .execute(query)?
        .sort(["aa"], Default::default())
        .collect()?;

    let aggregates = [
        col("b").sum().alias("bb"),
        (col("a") + col("b")).sum().alias("cc"),
        col("a").count().alias("count_a"),
        col("a").len().alias("count_star"),
    ];
    let from_native = sample
        .lazy()
        .group_by(&[col("a").alias("aa")])
        .agg(&aggregates)
        .limit(100)
        .sort(["aa"], Default::default())
        .collect()?;

    assert_eq!(from_sql, from_native);
    Ok(())
}
117
118
/// Grouping by a key that is also transformed in the projection: the `CASE`
/// rewrites `'zz'` to `'xx'` only in the output, while grouping still
/// happens on the original `a` values (three groups are expected).
#[test]
fn test_group_by_expression_key() -> PolarsResult<()> {
    let input = df! {
        "a" => &["xx", "yy", "xx", "yy", "xx", "zz"],
        "b" => &[1, 2, 3, 4, 5, 6],
        "c" => &[99, 99, 66, 66, 66, 66],
    }
    .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", input.lazy());

    // check how we handle grouping by a key that gets used in select transform
    let query = r#"
        SELECT
            CASE WHEN a = 'zz' THEN 'xx' ELSE a END AS grp,
            SUM(b) AS sum_b,
            SUM(c) AS sum_c,
        FROM df
        GROUP BY a
        ORDER BY sum_c
    "#;
    let from_sql = ctx
        .execute(query)?
        .sort(["sum_c"], Default::default())
        .collect()?;

    let expected = df! {
        "grp" => ["xx", "yy", "xx"],
        "sum_b" => [6, 6, 9],
        "sum_c" => [66, 165, 231],
    }
    .unwrap();

    assert_eq!(from_sql, expected);
    Ok(())
}
156
157
/// SQL `CAST` to each supported type name must match `Expr::cast` with the
/// corresponding Polars dtype.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_cast_exprs() {
    let df = create_sample_df();
    let sql = r#"
        SELECT
            cast(a as FLOAT) as f64,
            cast(a as FLOAT(24)) as f32,
            cast(a as INT) as ints,
            cast(a as BIGINT) as bigints,
            cast(a as STRING) as strings,
            cast(a as BLOB) as binary
        FROM df"#;

    assert_sql_to_polars(&df, sql, |lf| {
        lf.select(&[
            col("a").cast(DataType::Float64).alias("f64"),
            col("a").cast(DataType::Float32).alias("f32"),
            col("a").cast(DataType::Int32).alias("ints"),
            col("a").cast(DataType::Int64).alias("bigints"),
            col("a").cast(DataType::String).alias("strings"),
            col("a").cast(DataType::Binary).alias("binary"),
        ])
    });
}
186
187
/// Integer, float, string, boolean, null and interval literals in SQL must
/// match the corresponding `lit(..)` expressions.
///
/// The native side selects the first column alongside the literals and then
/// drops it again in a second pass. The comparison uses `equals_missing`
/// because of the null literal column.
#[test]
fn test_literal_exprs() {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let query = r#"
        SELECT
            1 as int_lit,
            1.0 as float_lit,
            'foo' as string_lit,
            true as bool_lit,
            null as null_lit,
            interval '2 weeks 1 day 50 seconds' as duration_lit
        FROM df"#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    let with_first_column = sample
        .lazy()
        .select(&[
            first().as_expr(),
            lit(1i64).alias("int_lit"),
            lit(1.0).alias("float_lit"),
            lit("foo").alias("string_lit"),
            lit(true).alias("bool_lit"),
            lit(NULL).alias("null_lit"),
            lit(Duration::parse("2w1d50s")).alias("duration_lit"),
        ])
        .collect()
        .unwrap();
    let from_native = with_first_column
        .lazy()
        .drop(first())
        .collect()
        .unwrap();

    assert!(from_sql.equals_missing(&from_native));
}
221
222
/// Comparing a date column (optionally cast to date/datetime/timestamp)
/// against a string literal must filter like a temporal comparison.
///
/// Every query is expected to keep exactly the rows with `idx >= 2`.
#[test]
fn test_implicit_date_string() {
    let raw = df! {
        "idx" => &[Some(0), Some(1), Some(2), Some(3)],
        "dt" => &[Some("1955-10-01"), None, Some("2007-07-05"), Some("2077-06-11")],
    }
    .unwrap();
    let dated = raw
        .lazy()
        .select(vec![col("idx"), col("dt").cast(DataType::Date)])
        .collect()
        .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("frame", dated.clone().lazy());

    let queries = [
        "SELECT idx, dt FROM frame WHERE dt >= '2007-07-05'",
        "SELECT idx, dt FROM frame WHERE dt::date >= '2007-07-05'",
        "SELECT idx, dt FROM frame WHERE dt::datetime >= '2007-07-05 00:00:00'",
        "SELECT idx, dt FROM frame WHERE dt::timestamp >= '2007-07-05 00:00:00'",
    ];
    for query in queries {
        let from_sql = ctx.execute(query).unwrap().collect().unwrap();
        let from_native = dated
            .clone()
            .lazy()
            .filter(col("idx").gt_eq(lit(2)))
            .collect()
            .unwrap();
        assert!(from_sql.equals(&from_native));
    }
}
252
253
/// Table-qualified column references (`df.a`) must resolve to the plain
/// columns.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_prefixed_column_names() {
    let df = create_sample_df();
    let sql = r#"
        SELECT
            df.a as a,
            df.b as b
        FROM df"#;

    assert_sql_to_polars(&df, sql, |lf| {
        lf.select(&[col("a").alias("a"), col("b").alias("b")])
    });
}
271
272
/// Same as `test_prefixed_column_names`, but with both the table and the
/// column identifiers double-quoted.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_prefixed_column_names_2() {
    let df = create_sample_df();
    let sql = r#"
        SELECT
            "df"."a" as a,
            "df"."b" as b
        FROM df"#;

    assert_sql_to_polars(&df, sql, |lf| {
        lf.select(&[col("a").alias("a"), col("b").alias("b")])
    });
}
290
291
/// `IS NULL` / `IS NOT NULL` in the projection must match
/// `is_null()` / `is_not_null()`.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_null_exprs() {
    let df = create_sample_df();
    let sql = r#"
        SELECT
            a,
            b,
            a is null as isnull_a,
            b is null as isnull_b,
            a is not null as isnotnull_a,
            b is not null as isnotnull_b
        FROM df"#;

    assert_sql_to_polars(&df, sql, |lf| {
        lf.select(&[
            col("a"),
            col("b"),
            col("a").is_null().alias("isnull_a"),
            col("b").is_null().alias("isnull_b"),
            col("a").is_not_null().alias("isnotnull_a"),
            col("b").is_not_null().alias("isnotnull_b"),
        ])
    });
}
320
321
/// `IS NULL` / `IS NOT NULL` inside a `WHERE` clause must match the
/// equivalent lazy filter. Compared with `equals_missing` because the
/// surviving row contains a null.
#[test]
fn test_null_exprs_in_where() {
    let input = df! {
        "a" => &[Some(1), None, Some(3)],
        "b" => &[Some(1), Some(2), None]
    }
    .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", input.clone().lazy());

    let query = r#"
        SELECT
            a,
            b
        FROM df
        WHERE a is null and b is not null"#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    let predicate = col("a").is_null().and(col("b").is_not_null());
    let from_native = input.lazy().filter(predicate).collect().unwrap();

    assert!(from_sql.equals_missing(&from_native));
}
346
347
/// Every SQL binary operator (arithmetic, comparison, logical and string
/// concatenation) must match its expression-API counterpart.
#[test]
fn test_binary_functions() {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let query = r#"
        SELECT
            a,
            b,
            a + b AS add,
            a - b AS sub,
            a * b AS mul,
            a / b AS div,
            a % b AS rem,
            a <> b AS neq,
            a = b AS eq,
            a > b AS gt,
            a < b AS lt,
            a >= b AS gte,
            a <= b AS lte,
            a and b AS and,
            a or b AS or,
            a xor b AS xor,
            a || b AS concat
        FROM df"#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    // `||` is mirrored by adding the two columns after casting them to strings.
    let concat = col("a").cast(DataType::String) + col("b").cast(DataType::String);
    let from_native = sample
        .lazy()
        .select(&[
            col("a"),
            col("b"),
            (col("a") + col("b")).alias("add"),
            (col("a") - col("b")).alias("sub"),
            (col("a") * col("b")).alias("mul"),
            (col("a") / col("b")).alias("div"),
            (col("a") % col("b")).alias("rem"),
            col("a").eq(col("b")).not().alias("neq"),
            col("a").eq(col("b")).alias("eq"),
            col("a").gt(col("b")).alias("gt"),
            col("a").lt(col("b")).alias("lt"),
            col("a").gt_eq(col("b")).alias("gte"),
            col("a").lt_eq(col("b")).alias("lte"),
            col("a").and(col("b")).alias("and"),
            col("a").or(col("b")).alias("or"),
            col("a").xor(col("b")).alias("xor"),
            concat.alias("concat"),
        ])
        .collect()
        .unwrap();

    assert_eq!(from_sql, from_native);
}
395
396
/// SQL aggregate functions must match their expression-API counterparts.
///
/// Currently ignored (marked non-deterministic upstream).
#[test]
#[ignore = "TODO: non deterministic"]
fn test_agg_functions() {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let query = r#"
        SELECT
            sum(a) AS sum_a,
            first(a) AS first_a,
            last(a) AS last_a,
            avg(a) AS avg_a,
            max(a) AS max_a,
            min(a) AS min_a,
            atan(a) AS atan_a,
            stddev(a) AS stddev_a,
            variance(a) AS variance_a,
            count(a) AS count_a,
            count(distinct a) AS count_distinct_a,
            count(*) AS count_all
        FROM df"#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    let aggregates = [
        col("a").sum().alias("sum_a"),
        col("a").first().alias("first_a"),
        col("a").last().alias("last_a"),
        col("a").mean().alias("avg_a"),
        col("a").max().alias("max_a"),
        col("a").min().alias("min_a"),
        col("a").arctan().alias("atan_a"),
        col("a").std(1).alias("stddev_a"),
        col("a").var(1).alias("variance_a"),
        col("a").count().alias("count_a"),
        col("a").n_unique().alias("count_distinct_a"),
        lit(1i32).count().alias("count_all"),
    ];
    let from_native = sample.lazy().select(&aggregates).collect().unwrap();

    assert!(from_sql.equals(&from_native));
}
438
439
/// `CREATE TABLE .. AS SELECT` must return a one-row response frame and make
/// the new table queryable from the same context.
#[test]
fn test_create_table() {
    let sample = create_sample_df();
    let mut ctx = SQLContext::new();
    ctx.register("df", sample.clone().lazy());

    let create_stmt = r#"
        CREATE TABLE df2 AS
        SELECT a
        FROM df"#;
    let response = ctx.execute(create_stmt).unwrap().collect().unwrap();
    let expected_response = df! {
        "Response" => ["CREATE TABLE df2"]
    }
    .unwrap();
    assert!(response.equals(&expected_response));

    // The freshly created table must now resolve by name.
    let from_new_table = ctx
        .execute(r#"SELECT a FROM df2"#)
        .unwrap()
        .collect()
        .unwrap();
    let expected_rows = sample.lazy().select(&[col("a")]).collect().unwrap();
    assert!(from_new_table.equals(&expected_rows));
}
465
466
/// A negative literal on the right-hand side of a comparison (`value < -1`)
/// must behave like `lit(-1)`.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_unary_minus_0() {
    let df = df! {
        "value" => [-5, -3, 0, 3, 5],
    }
    .unwrap();

    assert_sql_to_polars(&df, r#"SELECT * FROM df WHERE value < -1"#, |lf| {
        lf.filter(col("value").lt(lit(-1)))
    });
}
485
486
/// Unary minus applied to a column (`-value < 1`) must behave like
/// `0 - value`.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_unary_minus_1() {
    let df = df! {
        "value" => [-5, -3, 0, 3, 5],
    }
    .unwrap();

    assert_sql_to_polars(&df, r#"SELECT * FROM df WHERE -value < 1"#, |lf| {
        let neg_value = lit(0) - col("value");
        lf.filter(neg_value.lt(lit(1)))
    });
}
501
502
/// `ARRAY_AGG` in its supported forms (plain, multiple, `ORDER BY`,
/// `DISTINCT` + `UNNEST`, `ORDER BY .. LIMIT`) must match the equivalent
/// `implode()`-based expressions.
#[test]
fn test_arr_agg() {
    let sample = create_sample_df();

    let plain_a = (
        "SELECT ARRAY_AGG(a) AS a FROM df",
        vec![col("a").implode().alias("a")],
    );
    let a_and_b = (
        "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) AS b FROM df",
        vec![col("a").implode().alias("a"), col("b").implode().alias("b")],
    );
    let ordered_by_self = (
        "SELECT ARRAY_AGG(a ORDER BY a) AS a FROM df",
        vec![
            col("a")
                .sort_by(vec![col("a")], SortMultipleOptions::default())
                .implode()
                .alias("a"),
        ],
    );
    // NOTE(review): same query and expression as `plain_a`; kept so the set
    // of executed cases is unchanged.
    let plain_a_again = (
        "SELECT ARRAY_AGG(a) AS a FROM df",
        vec![col("a").implode().alias("a")],
    );
    let distinct_unnested = (
        "SELECT UNNEST(ARRAY_AGG(DISTINCT a)) FROM df",
        vec![
            col("a")
                .unique_stable()
                .implode()
                .explode(ExplodeOptions {
                    empty_as_null: true,
                    keep_nulls: true,
                })
                .alias("a"),
        ],
    );
    let ordered_and_limited = (
        "SELECT ARRAY_AGG(a ORDER BY b LIMIT 2) FROM df",
        vec![
            col("a")
                .sort_by(vec![col("b")], SortMultipleOptions::default())
                .head(Some(2))
                .implode(),
        ],
    );

    let cases = vec![
        plain_a,
        a_and_b,
        ordered_by_self,
        plain_a_again,
        distinct_unnested,
        ordered_and_limited,
    ];

    for (query, projection) in cases {
        assert_sql_to_polars(&sample, query, |lf| lf.select(&projection));
    }
}
555
556
/// `unnest(a)` next to a plain column must match `LazyFrame::explode` on
/// column `a`.
///
/// The input has two rows: the imploded `a` list paired with `b = "a"` and
/// the same list paired with `b = "b"`.
#[test]
fn test_explode_with_multiple_columns() {
    // Implode column "a"
    let imploded = create_sample_df()
        .lazy()
        .select(&[col("a").implode().alias("a")])
        .collect()
        .unwrap();

    // Attach a constant string column "b" to a copy of the imploded frame.
    let tagged = |tag: &'static str| {
        imploded
            .clone()
            .lazy()
            .with_column(lit(tag).alias("b"))
            .collect()
            .unwrap()
    };
    let stacked = tagged("a").vstack(&tagged("b")).unwrap();

    let native_plan = stacked.clone().lazy().explode(
        polars_lazy::dsl::Selector::ByName {
            names: Arc::from(vec!["a".into()]),
            strict: true,
        },
        ExplodeOptions {
            empty_as_null: true,
            keep_nulls: true,
        },
    );

    let mut ctx = SQLContext::new();
    ctx.register("df", stacked.lazy());

    let query = r#"
        SELECT
            unnest(a) AS a,
            b
        FROM df
    "#;

    let from_sql = ctx.execute(query).unwrap().collect().unwrap();
    let from_native = native_plan.collect().unwrap();
    assert!(from_sql.equals(&from_native));
}
605
606
/// The same `unnest(..)` call may appear several times in one projection
/// (directly and inside a `CASE`); the result must stay row-aligned.
#[test]
fn test_multiple_explodes_with_same_column() {
    let imploded = create_sample_df()
        .lazy()
        .select(&[
            col("a").implode().alias("list_a"),
            col("b").implode().alias("list_b"),
        ])
        .collect()
        .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", imploded.clone().lazy());

    let query = r#"
        SELECT
            unnest(list_a) AS list_a,
            unnest(list_b) AS list_b,
            CASE
                WHEN unnest(list_b) > 5000 THEN 'High'
                WHEN unnest(list_b) > 2500 THEN 'Medium'
                ELSE 'Low'
            END AS list_b_category
        FROM df
    "#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    // Rebuild the expected columns from first principles.
    let want_b: Vec<i64> = (1..10000).collect();
    let want_a: Vec<i64> = want_b.iter().map(|b| b / 100).collect();
    let want_category: Vec<&'static str> = want_b
        .iter()
        .map(|&b| match b {
            v if v > 5000 => "High",
            v if v > 2500 => "Medium",
            _ => "Low",
        })
        .collect();

    let expected = DataFrame::new_infer_height(vec![
        Column::new(PlSmallStr::from_static("list_a"), want_a),
        Column::new(PlSmallStr::from_static("list_b"), want_b),
        Column::new(PlSmallStr::from_static("list_b_category"), want_category),
    ])
    .unwrap();

    assert!(from_sql.equals(&expected));
    assert!(from_sql.shape() == (9_999, 3));
}
661
662
/// Unnesting two different list columns next to a scalar column must agree
/// between SQL, `LazyFrame::explode`, and a hand-built expected frame.
#[test]
fn test_multiple_explodes_different_columns() {
    let imploded = create_sample_df()
        .lazy()
        .select(&[
            col("a").implode().alias("list_a"),
            col("b").implode().alias("list_b"),
        ])
        .collect()
        .unwrap();

    // Add scalar to check if row-bound mapping stays consistent.
    let with_scalar = imploded
        .lazy()
        .with_column(lit(100).alias("value"))
        .collect()
        .unwrap();

    // Test using both the Polars API and SQL
    let both_lists = polars_lazy::dsl::Selector::ByName {
        names: Arc::from(vec!["list_a".into(), "list_b".into()]),
        strict: true,
    };
    let explode_options = ExplodeOptions {
        empty_as_null: true,
        keep_nulls: true,
    };
    let from_native = with_scalar
        .clone()
        .lazy()
        .explode(both_lists, explode_options)
        .collect()
        .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", with_scalar.clone().lazy());

    let query = r#"
        SELECT
            unnest(list_a) AS list_a,
            unnest(list_b) AS list_b,
            value
        FROM df
    "#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    let want_a: Vec<i64> = (1..10000).map(|i| i / 100).collect();
    let want_b: Vec<i64> = (1..10000).collect();
    let want_value: Vec<i32> = vec![100i32; want_b.len()];

    let expected = DataFrame::new_infer_height(vec![
        Column::new(PlSmallStr::from_static("list_a"), want_a),
        Column::new(PlSmallStr::from_static("list_b"), want_b),
        Column::new(PlSmallStr::from_static("value"), want_value),
    ])
    .unwrap();

    assert!(from_sql.equals(&from_native));
    assert!(from_sql.equals(&expected));
    assert!(from_native.equals(&expected));
}
726
727
/// Chained CTEs that each `unnest` a column and re-alias it to the same name
/// must match three successive `LazyFrame::explode` calls on that column.
///
/// The input is a nested list column that is imploded once more, so three
/// explode steps are needed to reach the flat values `1..=8`.
#[test]
fn explode_same_name_with_cte() {
    let anon = || PlSmallStr::from_static("");
    let pair = |lo: i64, hi: i64| Series::new(anon(), &[lo, hi]);

    let values = vec![
        Series::new(anon(), vec![pair(1, 2), pair(3, 4)]),
        Series::new(anon(), vec![pair(5, 6), pair(7, 8)]),
    ];
    let nested = DataFrame::new_infer_height(vec![Column::new(
        PlSmallStr::from_static("list_a"),
        values,
    )])
    .unwrap();

    let imploded = nested
        .lazy()
        .select(&[col("list_a").implode().alias("list_a")])
        .collect()
        .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", imploded.clone().lazy());

    let query = r#"
        WITH exploded AS (
            SELECT
                unnest(list_a) AS list_a
            FROM df
        ),
        exploded_2 AS (
            SELECT
                unnest(list_a) AS list_a
            FROM exploded
        )
        SELECT
            unnest(list_a) AS list_a
        FROM exploded_2
    "#;
    let from_sql = ctx.execute(query).unwrap().collect().unwrap();

    // One explode per nesting level, mirroring the three SELECTs above.
    let from_native = (0..3)
        .fold(imploded.lazy(), |lf, _| {
            lf.explode(
                polars_lazy::dsl::Selector::ByName {
                    names: Arc::from(vec!["list_a".into()]),
                    strict: true,
                },
                ExplodeOptions {
                    empty_as_null: true,
                    keep_nulls: true,
                },
            )
        })
        .collect()
        .unwrap();

    let flat_values = vec![1i64, 2, 3, 4, 5, 6, 7, 8];
    let expected = DataFrame::new_infer_height(vec![Column::new(
        PlSmallStr::from_static("list_a"),
        flat_values,
    )])
    .unwrap();

    assert!(from_sql.equals(&from_native));
    assert!(from_sql.equals(&expected));
    assert!(from_native.equals(&expected));
}
821
822
/// CTE names must resolve whether they are quoted or not, and must not leak
/// out of the statement that defined them.
#[test]
fn test_ctes() -> PolarsResult<()> {
    let mut ctx = SQLContext::new();
    ctx.register("df", create_sample_df().lazy());

    // note: confirm correct behaviour of quoted/unquoted CTE identifiers
    let quoted_definition = r#"WITH "df0" AS (SELECT * FROM "df") SELECT * FROM df0 "#;
    assert!(ctx.execute(quoted_definition).is_ok());

    let quoted_reference = r#"WITH df0 AS (SELECT * FROM df) SELECT * FROM "df0" "#;
    assert!(ctx.execute(quoted_reference).is_ok());

    // `df0` only existed inside the statements above.
    let dangling_reference = r#"SELECT * FROM df0"#;
    assert!(ctx.execute(dangling_reference).is_err());

    Ok(())
}
841
842
/// A CTE built on a `VALUES` table, with column-renaming CTE headers and a
/// nested `WITH`, must be accepted by a context with no registered tables.
#[test]
fn test_cte_values() -> PolarsResult<()> {
    let query = r#"
        WITH
        x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)),
        y (m, n) AS (
          WITH z(c, d) AS (SELECT a, b FROM x)
            SELECT d*2 AS d2, c*3 AS c3 FROM z
        )
        SELECT n, m FROM y
    "#;

    let mut ctx = SQLContext::new();
    let outcome = ctx.execute(query);
    assert!(outcome.is_ok());

    Ok(())
}
858
859
/// Grouped aggregation with a multi-key `ORDER BY` and `LIMIT` over a table
/// created from `read_ipc` must match the lazy pipeline over the same file.
///
/// Only compiled when the `ipc` feature is enabled.
#[test]
#[cfg(feature = "ipc")]
fn test_group_by_2() -> PolarsResult<()> {
    use polars_utils::pl_path::PlRefPath;

    let mut ctx = SQLContext::new();

    let create_foods = r#"
    CREATE TABLE foods AS
    SELECT *
    FROM read_ipc('../../examples/datasets/foods1.ipc')"#;
    ctx.execute(create_foods)?.collect()?;

    let summarise = r#"
    SELECT
        category,
        count(category) AS count,
        max(calories),
        min(fats_g)
    FROM foods
    GROUP BY category
    ORDER BY count, category DESC
    LIMIT 2"#;
    let from_sql = ctx.execute(summarise)?.collect()?;

    // `count` ascending, `category` descending — mirrors the ORDER BY above.
    let sort_options =
        SortMultipleOptions::default().with_order_descending_multi([false, true]);
    let from_native = LazyFrame::scan_ipc(
        PlRefPath::new("../../examples/datasets/foods1.ipc"),
        Default::default(),
        Default::default(),
    )?
    .select(&[col("*")])
    .group_by(vec![col("category")])
    .agg(vec![
        col("category").count().alias("count"),
        col("calories").max(),
        col("fats_g").min(),
    ])
    .sort_by_exprs(vec![col("count"), col("category")], sort_options)
    .limit(2)
    .collect()?;

    assert!(from_sql.equals(&from_native));
    Ok(())
}
906
907
/// A searched `CASE WHEN .. THEN .. ELSE .. END` must match the
/// `when/then/otherwise` chain. Runs on the first 10 rows of the sample
/// frame.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_case_expr() {
    let df = create_sample_df().head(Some(10));
    let sql = r#"
        SELECT
            CASE
                WHEN (a > 5 AND a < 8) THEN 'gt_5 and lt_8'
                WHEN a <= 5 THEN 'lteq_5'
                ELSE 'no match'
            END AS sign
        FROM df"#;

    assert_sql_to_polars(&df, sql, |lf| {
        let case_expr = when(col("a").gt(lit(5)).and(col("a").lt(lit(8))))
            .then(lit("gt_5 and lt_8"))
            .when(col("a").lt_eq(lit(5)))
            .then(lit("lteq_5"))
            .otherwise(lit("no match"))
            .alias("sign");
        lf.select(&[case_expr])
    });
}
931
932
/// A simple `CASE <expr> WHEN <value> ..` must match a `when/then/otherwise`
/// chain that compares the operand to each value with `eq`.
///
/// Uses the shared `assert_sql_to_polars` helper instead of repeating the
/// register / execute / collect / `equals` boilerplate.
#[test]
fn test_case_expr_with_expression() {
    let df = create_sample_df();
    let sql = r#"
        SELECT
            CASE b%2
                WHEN 0 THEN 'even'
                WHEN 1 THEN 'odd'
                ELSE 'No?'
            END AS parity
        FROM df"#;

    assert_sql_to_polars(&df, sql, |lf| {
        let case_expr = when((col("b") % lit(2)).eq(lit(0)))
            .then(lit("even"))
            .when((col("b") % lit(2)).eq(lit(1)))
            .then(lit("odd"))
            .otherwise(lit("No?"))
            .alias("parity");
        lf.select(&[case_expr])
    });
}
957
958
/// `sql_expr` must parse a standalone SQL expression into an `Expr` that
/// evaluates like the hand-written one.
#[test]
fn test_sql_expr() {
    let sample = create_sample_df();

    let parsed = sql_expr("MIN(a)").unwrap();
    let via_sql = sample.clone().lazy().select(&[parsed]).collect().unwrap();

    let via_api = sample.lazy().select(&[col("a").min()]).collect().unwrap();

    assert!(via_sql.equals(&via_api));
}
966
967
/// Regression test: calling `ABS` with a nonsensical argument list must
/// produce an error from `execute`.
#[test]
fn test_iss_9471() {
    let input = df! {
        "a" => [-4, -3, -2, -1, 0, 1, 2, 3, 4],
    }
    .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", input.lazy());

    let query = r#"
        SELECT
            ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") AS "abs",
        FROM df"#;
    let outcome = ctx.execute(query);

    assert!(outcome.is_err())
}
986
987
/// `ORDER BY` may reference a column removed from the output by `EXCLUDE`,
/// for both `*` and `table.*` wildcards.
#[test]
fn test_order_by_excluded_column() {
    let input = df! {
        "x" => [0, 1, 2, 3],
        "y" => [4, 2, 0, 8],
    }
    .unwrap();

    let mut ctx = SQLContext::new();
    ctx.register("df", input.lazy());

    let queries = [
        "SELECT * EXCLUDE y FROM df ORDER BY y",
        "SELECT df.* EXCLUDE y FROM df ORDER BY y",
    ];
    for query in queries {
        let sorted = ctx.execute(query).unwrap().collect().unwrap();
        // `x` reordered by ascending `y` (0, 2, 4, 8).
        let expected = df! {"x" => [2, 1, 0, 3],}.unwrap();
        assert!(sorted.equals(&expected));
    }
}
1008
1009
/// Struct fields must be selectable via `struct.*`, dotted paths and `->`,
/// and usable in `ORDER BY` / `GROUP BY`.
#[test]
fn test_struct_field_selection() {
    let (struct_frame, flat_frame) = create_struct_df();

    let mut ctx = SQLContext::new();
    ctx.register("df", struct_frame.lazy());

    // Each ordering below is expected to reproduce the original row order.
    let wildcard_queries = [
        r#"SELECT json_msg.* FROM df ORDER BY 1"#,
        r#"SELECT df.json_msg.* FROM df ORDER BY 3 DESC"#,
        r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#,
        r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.val DESC"#,
    ];
    for query in wildcard_queries {
        let unpacked = ctx.execute(query).unwrap().collect().unwrap();
        assert!(unpacked.equals(&flat_frame));
    }

    let grouped_query = r#"
        SELECT
          json_msg.str AS id,
          SUM(json_msg -> 'num') AS sum_n
        FROM df
        GROUP BY json_msg.str
        ORDER BY 1
    "#;
    let grouped = ctx.execute(grouped_query).unwrap().collect().unwrap();
    let expected_grouped = df! {
        "id" => ["a", "b"],
        "sum_n" => [600, 400],
    }
    .unwrap();
    assert!(grouped.equals(&expected_grouped));
}
1042
1043