# GitHub Repository: pola-rs/polars
# Path: py-polars/tests/unit/sql/test_miscellaneous.py
from __future__ import annotations

from datetime import date
from pathlib import Path
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
    from polars.datatypes import DataType


@pytest.fixture
def foods_ipc_path() -> Path:
    return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"

def test_any_all() -> None:
    df = pl.DataFrame(
        {
            "x": [-1, 0, 1, 2, 3, 4],
            "y": [1, 0, 0, 1, 2, 3],
        }
    )
    res = pl.sql(
        """
        SELECT
          x >= ALL(df.y) AS "All Geq",
          x > ALL(df.y) AS "All G",
          x < ALL(df.y) AS "All L",
          x <= ALL(df.y) AS "All Leq",
          x >= ANY(df.y) AS "Any Geq",
          x > ANY(df.y) AS "Any G",
          x < ANY(df.y) AS "Any L",
          x <= ANY(df.y) AS "Any Leq",
          x == ANY(df.y) AS "Any eq",
          x != ANY(df.y) AS "Any Neq",
        FROM df
        """,
    ).collect()

    assert res.to_dict(as_series=False) == {
        "All Geq": [0, 0, 0, 0, 1, 1],
        "All G": [0, 0, 0, 0, 0, 1],
        "All L": [1, 0, 0, 0, 0, 0],
        "All Leq": [1, 1, 0, 0, 0, 0],
        "Any Geq": [0, 1, 1, 1, 1, 1],
        "Any G": [0, 0, 1, 1, 1, 1],
        "Any L": [1, 1, 1, 1, 0, 0],
        "Any Leq": [1, 1, 1, 1, 1, 0],
        "Any eq": [0, 1, 1, 1, 1, 0],
        "Any Neq": [1, 0, 0, 0, 0, 1],
    }

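# Sketch (added for illustration; not part of the upstream suite): the ANY/ALL
# comparisons above can be phrased as native Polars expressions, e.g.
# `x >= ALL(y)` compares against the column maximum, and `x == ANY(y)` is a
# membership test.
def _sketch_any_all_native() -> None:
    df = pl.DataFrame({"x": [-1, 0, 1, 2, 3, 4], "y": [1, 0, 0, 1, 2, 3]})
    out = df.select(
        all_geq=pl.col("x") >= pl.col("y").max(),  # x >= ALL(y)
        any_eq=pl.col("x").is_in(pl.col("y")),  # x == ANY(y)
    )
    assert out["all_geq"].to_list() == [False, False, False, False, True, True]
    assert out["any_eq"].to_list() == [False, True, True, True, True, False]
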
@pytest.mark.parametrize(
    ("data", "schema"),
    [
        ({"x": [1, 2, 3, 4]}, None),
        ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}),
        ({"x": ["aa", "bb"]}, {"x": pl.Struct}),
        ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}),
    ],
)
def test_boolean_where_clauses(
    data: dict[str, Any], schema: dict[str, DataType] | None
) -> None:
    df = pl.DataFrame(data=data, schema=schema)
    empty_df = df.clear()

    for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"):
        assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}"))

    for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"):
        assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}"))

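# Sketch (added for illustration): `DataFrame.clear()` keeps the schema while
# dropping all rows, which is why it models the result of an always-false
# WHERE clause above.
def _sketch_clear_keeps_schema() -> None:
    df = pl.DataFrame({"x": [1, 2, 3]})
    empty = df.clear()
    assert empty.height == 0
    assert empty.schema == df.schema
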
def test_count() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 1, 22, 22, 333],
            "c": [1, 1, None, None, 2],
        }
    )
    res = df.sql(
        """
        SELECT
          -- count
          COUNT(a) AS count_a,
          COUNT(b) AS count_b,
          COUNT(c) AS count_c,
          COUNT(*) AS count_star,
          COUNT(NULL) AS count_null,
          -- count distinct
          COUNT(DISTINCT a) AS count_unique_a,
          COUNT(DISTINCT b) AS count_unique_b,
          COUNT(DISTINCT c) AS count_unique_c,
          COUNT(DISTINCT NULL) AS count_unique_null,
        FROM self
        """,
    )
    assert res.to_dict(as_series=False) == {
        "count_a": [5],
        "count_b": [5],
        "count_c": [3],
        "count_star": [5],
        "count_null": [0],
        "count_unique_a": [5],
        "count_unique_b": [3],
        "count_unique_c": [2],
        "count_unique_null": [0],
    }

    df = pl.DataFrame({"x": [None, None, None]})
    res = df.sql(
        """
        SELECT
          COUNT(x) AS count_x,
          COUNT(*) AS count_star,
          COUNT(DISTINCT x) AS count_unique_x
        FROM self
        """
    )
    assert res.to_dict(as_series=False) == {
        "count_x": [0],
        "count_star": [3],
        "count_unique_x": [0],
    }

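# Sketch (added for illustration): native analogues of the SQL counts above.
# Note that `n_unique` counts null as a distinct value, so COUNT(DISTINCT c)
# corresponds to dropping nulls first.
def _sketch_count_native() -> None:
    df = pl.DataFrame({"c": [1, 1, None, None, 2]})
    assert df.select(pl.col("c").count()).item() == 3  # COUNT(c): non-null only
    assert df.select(pl.len()).item() == 5  # COUNT(*)
    assert df.select(pl.col("c").drop_nulls().n_unique()).item() == 2  # COUNT(DISTINCT c)
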
def test_cte_aliasing() -> None:
    df1 = pl.DataFrame({"colx": ["aa", "bb"], "coly": [40, 30]})
    df2 = pl.DataFrame({"colx": "aa", "colz": 20})
    df3 = pl.sql(
        query="""
        WITH
          test1 AS (SELECT * FROM df1),
          test2 AS (SELECT * FROM df2),
          test3 AS (
            SELECT ROW_NUMBER() OVER (ORDER BY t1.colx) AS n, t1.colx, t2.colz
            FROM test1 t1
            LEFT JOIN test2 t2 ON t1.colx = t2.colx
          )
        SELECT * FROM test3 t3 ORDER BY colx DESC
        """,
        eager=True,
    )
    expected = [(2, "bb", None), (1, "aa", 20)]
    assert expected == df3.rows()


def test_distinct() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 2, 3],
            "b": [1, 2, 3, 4, 5, 6],
        }
    )
    ctx = pl.SQLContext(register_globals=True, eager=True)
    res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC")
    assert_frame_equal(
        left=df.select("a").unique().sort(by="a", descending=True),
        right=res1,
    )

    res2 = ctx.execute(
        """
        SELECT DISTINCT
          a * 2 AS two_a,
          b / 2 AS half_b
        FROM df
        ORDER BY two_a ASC, half_b DESC
        """,
    )
    assert res2.to_dict(as_series=False) == {
        "two_a": [2, 2, 4, 6],
        "half_b": [1, 0, 2, 3],
    }

    # test unregistration
    ctx.unregister("df")
    with pytest.raises(SQLInterfaceError, match="relation 'df' was not found"):
        ctx.execute("SELECT * FROM df")

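# Sketch (added for illustration): frames can also be registered explicitly at
# construction time, rather than collected from the global namespace.
def _sketch_explicit_registration() -> None:
    frame = pl.DataFrame({"a": [1, 2]})
    with pl.SQLContext(tbl=frame, eager=True) as ctx:
        assert ctx.execute("SELECT a FROM tbl").rows() == [(1,), (2,)]
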
def test_frame_sql_globals_error() -> None:
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df2 = pl.DataFrame({"a": [2, 3, 4], "b": [7, 6, 5]})

    query = """
        SELECT df1.a, df2.b
        FROM df2 JOIN df1 ON df1.a = df2.a
        ORDER BY b DESC
    """
    with pytest.raises(SQLInterfaceError, match=r"relation.*not found.*"):
        df1.sql(query=query)

    res = pl.sql(query=query, eager=True)
    assert res.to_dict(as_series=False) == {"a": [2, 3], "b": [7, 6]}


def test_global_misc_lookup() -> None:
    # check that `col` in the global namespace is not incorrectly identified
    # as supporting pycapsule (as it can look like it has *any* attribute)
    from polars import col  # noqa: F401

    df = pl.DataFrame({"col": [90, 80, 70]})
    df_res = pl.sql("SELECT col FROM df WHERE col > 75", eager=True)
    assert df_res.rows() == [(90,), (80,)]

def test_in_no_ops_11946() -> None:
    lf = pl.LazyFrame(
        [
            {"i1": 1},
            {"i1": 2},
            {"i1": 3},
        ]
    )
    out = lf.sql(
        query="SELECT * FROM frame_data WHERE i1 in (1, 3)",
        table_name="frame_data",
    ).collect()
    assert out.to_dict(as_series=False) == {"i1": [1, 3]}

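# Sketch (added for illustration): when `table_name` is omitted, the frame is
# addressable as "self", mirroring the query above.
def _sketch_default_table_name() -> None:
    lf = pl.LazyFrame({"i1": [1, 2, 3]})
    out = lf.sql("SELECT * FROM self WHERE i1 IN (1, 3)").collect()
    assert out["i1"].to_list() == [1, 3]
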
def test_limit_offset() -> None:
    n_values = 11
    lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))})
    ctx = pl.SQLContext(tbl=lf)

    assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [
        (4, 6),
        (5, 5),
        (6, 4),
    ]
    for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]:
        out = ctx.execute(
            f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True
        )
        assert_frame_equal(out, lf.slice(offset, limit).collect())
        assert len(out) == min(limit, n_values - offset)


def test_nested_subquery_table_leakage() -> None:
    a = pl.LazyFrame({"id": [1, 2, 3]})
    b = pl.LazyFrame({"val": [2, 3, 4]})

    ctx = pl.SQLContext(a=a, b=b)
    ctx.execute("""
        SELECT *
        FROM a
        WHERE id IN (
          SELECT derived.val
          FROM (SELECT val FROM b) AS derived
        )
    """)

    # after executing the above query, confirm that the inner "derived" table
    # alias is not still registered in the context
    with pytest.raises(
        SQLInterfaceError,
        match="relation 'derived' was not found",
    ):
        ctx.execute("SELECT * FROM derived")

def test_register_context() -> None:
    # context manager usage should unregister tables created in each
    # scope on context exit; supports arbitrary levels of nesting
    with pl.SQLContext() as ctx:
        _lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["m", "n", "o"]})
        _lf2 = pl.LazyFrame({"a": [2, 3, 4], "c": ["p", "q", "r"]})

        ctx.register_globals()
        assert ctx.tables() == ["_lf1", "_lf2"]

        with ctx:
            _lf3 = pl.LazyFrame({"a": [3, 4, 5], "b": ["s", "t", "u"]})
            _lf4 = pl.LazyFrame({"a": [4, 5, 6], "c": ["v", "w", "x"]})
            ctx.register_globals(n=2)
            assert ctx.tables() == ["_lf1", "_lf2", "_lf3", "_lf4"]

        assert ctx.tables() == ["_lf1", "_lf2"]

    assert ctx.tables() == []

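# Sketch (added for illustration): the scoped registration above can also be
# managed manually with `register` / `unregister`.
def _sketch_manual_register() -> None:
    ctx = pl.SQLContext()
    ctx.register("t", pl.LazyFrame({"a": [1]}))
    assert ctx.tables() == ["t"]
    ctx.unregister("t")
    assert ctx.tables() == []
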
def test_sql_on_compatible_frame_types() -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # create various different frame types
    dfp = df.to_pandas()
    dfa = df.to_arrow()
    dfb = dfa.to_batches()[0]
    dfo = PyCapsuleStreamHolder(df)

    # run a polars sql query against all frame types
    for dfs in (  # noqa: B007
        (df["a"] * 2).rename("c"),  # polars series
        (dfp["a"] * 2).rename("c"),  # pandas series
    ):
        res = pl.sql(
            """
            SELECT a, b, SUM(c) AS cc FROM (
              SELECT * FROM df             -- polars frame
              UNION ALL SELECT * FROM dfp  -- pandas frame
              UNION ALL SELECT * FROM dfa  -- pyarrow table
              UNION ALL SELECT * FROM dfb  -- pyarrow record batch
              UNION ALL SELECT * FROM dfo  -- arbitrary pycapsule object
            ) tbl
            INNER JOIN dfs ON dfs.c == tbl.b  -- join on pandas/polars series
            GROUP BY "a", "b"
            ORDER BY "a", "b"
            """
        ).collect()

        expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [20, 30]})
        assert_frame_equal(left=expected, right=res)

    # register and operate on non-polars frames
    for obj in (dfa, dfp):
        with pl.SQLContext(obj=obj) as ctx:
            res = ctx.execute("SELECT * FROM obj", eager=True)
            assert_frame_equal(df, res)

    # `register_globals` does not auto-register every compatible object
    with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"):
        pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp")

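# Sketch (added for illustration; assumes pyarrow is installed): any object
# exposing the Arrow PyCapsule stream interface can be queried directly, which
# is what the PyCapsuleStreamHolder wrapper above exercises.
def _sketch_query_pyarrow_table() -> None:
    import pyarrow as pa

    tbl = pa.table({"a": [1, 2, 3]})
    out = pl.sql("SELECT a FROM tbl WHERE a > 1", eager=True)
    assert out["a"].to_list() == [2, 3]
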
def test_nested_cte_column_aliasing() -> None:
    # trace through nested CTEs with multiple levels of column & table aliasing
    df = pl.sql(
        """
        WITH
          x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)),
          y (m, n) AS (
            WITH z(c, d) AS (SELECT a, b FROM x)
            SELECT d*2 AS d2, c*3 AS c3 FROM z
          )
        SELECT n, m FROM y
        """,
        eager=True,
    )
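    # tracing the aliases: the VALUES rows are (a, b) = (1, 2), (3, 4);
    # z(c, d) renames (a, b); y(m, n) renames (d2, c3) = (d*2, c*3),
    # so m = [4, 8] and n = [3, 9]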
    assert df.to_dict(as_series=False) == {
        "n": [3, 9],
        "m": [4, 8],
    }

def test_invalid_derived_table_column_aliases() -> None:
    values_query = "SELECT * FROM (VALUES (1,2), (3,4))"

    with pytest.raises(
        SQLSyntaxError,
        match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)",
    ):
        pl.sql(f"{values_query} AS tbl(a, b, c, d, e)")

    assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)]


def test_values_clause_table_registration() -> None:
    with pl.SQLContext(frames=None, eager=True) as ctx:
        # initially no tables are registered
        assert ctx.tables() == []

        # confirm that the VALUES clause derived table is registered, post-query
        res1 = ctx.execute("SELECT * FROM (VALUES (-1,1)) AS tbl(x, y)")
        assert ctx.tables() == ["tbl"]

        # ...and confirm that we can select from it by the registered name
        res2 = ctx.execute("SELECT x, y FROM tbl")
        for res in (res1, res2):
            assert res.to_dict(as_series=False) == {"x": [-1], "y": [1]}

def test_read_csv(tmp_path: Path) -> None:
    # check empty string vs null, parsing of dates, etc.
    df = pl.DataFrame(
        {
            "label": ["lorem", None, "", "ipsum"],
            "num": [-1, None, 0, 1],
            "dt": [
                date(1969, 7, 5),
                date(1999, 12, 31),
                date(2077, 10, 10),
                None,
            ],
        }
    )
    csv_target = tmp_path / "test_sql_read.csv"
    df.write_csv(csv_target)

    res = pl.sql(f"SELECT * FROM read_csv('{csv_target}')").collect()
    assert_frame_equal(df, res)

    with pytest.raises(
        SQLSyntaxError,
        match="`read_csv` expects a single file path; found 3 arguments",
    ):
        pl.sql("SELECT * FROM read_csv('a','b','c')")


def test_global_variable_inference_17398() -> None:
    users = pl.DataFrame({"id": "1"})

    res = pl.sql(
        query="""
        WITH user_by_email AS (SELECT id FROM users)
        SELECT * FROM user_by_email
        """,
        eager=True,
    )
    assert_frame_equal(res, users)

@pytest.mark.parametrize(
    "query",
    [
        "SELECT invalid_column FROM self",
        "SELECT key, invalid_column FROM self",
        "SELECT invalid_column * 2 FROM self",
        "SELECT * FROM self ORDER BY invalid_column",
        "SELECT * FROM self WHERE invalid_column = 200",
        "SELECT * FROM self WHERE invalid_column = '200'",
        "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column",
    ],
)
def test_invalid_cols(query: str) -> None:
    df = pl.DataFrame(
        {
            "key": ["xx", "xx", "yy"],
            "n": ["100", "200", "300"],
        }
    )
    with pytest.raises(ColumnNotFoundError, match="invalid_column"):
        df.sql(query)

@pytest.mark.parametrize("filter_expr", ["", "WHERE 1 = 1", "WHERE a == 1 OR a != 1"])
449
@pytest.mark.parametrize("order_expr", ["", "ORDER BY 1", "ORDER BY a"])
450
def test_select_output_heights_20058_21084(filter_expr: str, order_expr: str) -> None:
451
df = pl.DataFrame({"a": [1, 2, 3]})
452
453
# Queries that maintain original height
454
455
assert_frame_equal(
456
df.sql(f"SELECT 1 as a FROM self {filter_expr} {order_expr}").cast(pl.Int64),
457
pl.select(a=pl.Series([1, 1, 1])),
458
)
459
460
assert_frame_equal(
461
df.sql(f"SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
462
pl.Int64
463
),
464
pl.DataFrame({"a": [2, 2, 2], "b": [1, 1, 1]}),
465
)
466
467
# Queries that aggregate to unit height
468
469
assert_frame_equal(
470
df.sql(f"SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}").cast(
471
pl.Int64
472
),
473
pl.DataFrame({"a": 3}),
474
)
475
476
assert_frame_equal(
477
df.sql(
478
f"SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}"
479
).cast(pl.Int64),
480
pl.DataFrame({"a": 3, "b": 1}),
481
)
482
483
assert_frame_equal(
484
df.sql(
485
f"SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}"
486
).cast(pl.Int64),
487
pl.DataFrame({"a": 1, "b": 1}),
488
)
489
490
assert_frame_equal(
491
df.sql(f"SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
492
pl.Int64
493
),
494
pl.DataFrame({"a": 6, "b": 1}),
495
)
496
497
assert_frame_equal(
498
df.sql(
499
f"SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}"
500
).cast(pl.Int64),
501
pl.DataFrame({"a": 1, "b": 1}),
502
)
503
504
assert_frame_equal(
505
df.sql(
506
f"SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}"
507
).cast(pl.Int64),
508
pl.DataFrame({"a": 2, "b": 1}),
509
)
510
511
assert_frame_equal(
512
df.sql(
513
f"SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}"
514
).cast(pl.Int64),
515
pl.DataFrame({"a": 2, "b": 1}),
516
)
517
518
519
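# Sketch (added for illustration): contrast with native `select`, where a bare
# literal is not broadcast to the frame height.
def _sketch_literal_height_contrast() -> None:
    df = pl.DataFrame({"a": [1, 2, 3]})
    assert df.select(pl.lit(1)).height == 1  # native: unit height
    assert df.sql("SELECT 1 AS one FROM self").height == 3  # SQL: table height
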
def test_select_explode_height_filter_order_by() -> None:
    # Note: SQL `UNNEST()` equates to `pl.DataFrame.explode()`;
    # the ordering is applied after the explosion/unnest
    df = pl.DataFrame(
        {
            "list_long": [[1, 2, 3], [4, 5, 6]],
            "sort_key": [2, 1],
            "filter_mask": [False, True],
            "filter_mask_all_true": True,
        }
    )

    # Unnest/explode is applied at the dataframe level; sort is applied afterward
    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key"),
        pl.Series("list", [4, 5, 6, 1, 2, 3]).to_frame(),
    )

    # No NULLs appear, since the ordering is applied after the dataframe-level explode
    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST"
        ),
        pl.Series("list", [4, 5, 6, 1, 2, 3]).to_frame(),
    )

    # Literals are broadcast to the output height of UNNEST:
    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key"),
        pl.select(pl.Series("list", [4, 5, 6, 1, 2, 3]), x=1),
    )

    # Note: the filter applies before projections in SQL
    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key"
        ),
        pl.Series("list", [4, 5, 6]).to_frame(),
    )

    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key"
        ),
        pl.Series("list", [4, 5, 6, 1, 2, 3]).to_frame(),
    )

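# Sketch (added for illustration): the native analogue of the first UNNEST
# query above, exploding first and then sorting by the repeated key.
def _sketch_native_explode() -> None:
    df = pl.DataFrame({"list_long": [[1, 2, 3], [4, 5, 6]], "sort_key": [2, 1]})
    out = df.explode("list_long").sort("sort_key", maintain_order=True)
    assert out["list_long"].to_list() == [4, 5, 6, 1, 2, 3]
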
@pytest.mark.parametrize(
    ("query", "result"),
    [
        (
            """SELECT a, COUNT(*) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT() OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT(i) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 1, 1, 1],
        ),
        (
            """SELECT a, COUNT(DISTINCT i) OVER (PARTITION BY a) AS b FROM self""",
            [2, 2, 2, 1, 1, 1, 1],
        ),
    ],
)
def test_count_partition_22665(query: str, result: list[Any]) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 3, 3, 3],
            "i": [0, 0, 1, 2, 3, None, None],
        }
    )
    out = df.sql(query).select("b")
    expected = pl.DataFrame({"b": result}).cast({"b": pl.get_index_type()})
    assert_frame_equal(out, expected)

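# Sketch (added for illustration): the COUNT(i) OVER (PARTITION BY a) case maps
# to the native window expression `pl.col("i").count().over("a")`.
def _sketch_count_over() -> None:
    df = pl.DataFrame(
        {"a": [1, 1, 1, 2, 3, 3, 3], "i": [0, 0, 1, 2, 3, None, None]}
    )
    out = df.select(pl.col("i").count().over("a").alias("b"))
    assert out["b"].to_list() == [3, 3, 3, 1, 1, 1, 1]
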
@pytest.mark.parametrize(
    "query",
    [
        # ClickHouse-specific PREWHERE clause
        "SELECT x, y FROM df PREWHERE z IS NOT NULL",
        # Hive/Spark LATERAL VIEW syntax
        "SELECT * FROM person LATERAL VIEW EXPLODE(ARRAY(0,125)) tableName AS age",
        # Oracle-style hierarchical queries
        """
        SELECT employee_id, employee_name, manager_id, LEVEL AS hierarchy_level
        FROM employees
        START WITH manager_id IS NULL
        CONNECT BY PRIOR employee_id = manager_id
        """,
    ],
)
def test_unsupported_select_clauses(query: str) -> None:
    # ensure we're actively catching unsupported clauses
    with (
        pl.SQLContext() as ctx,
        pytest.raises(
            SQLInterfaceError,
            match=r"not.*supported",
        ),
    ):
        ctx.execute(query)