# Source: pola-rs/polars, path: py-polars/tests/unit/sql/test_miscellaneous.py

from __future__ import annotations

from datetime import date
from pathlib import Path
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
    from polars.datatypes import DataType


@pytest.fixture
def foods_ipc_path() -> Path:
    return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc"


def test_any_all() -> None:
    df = pl.DataFrame(  # noqa: F841
        {
            "x": [-1, 0, 1, 2, 3, 4],
            "y": [1, 0, 0, 1, 2, 3],
        }
    )
    res = pl.sql(
        """
        SELECT
          x >= ALL(df.y) AS "All Geq",
          x > ALL(df.y) AS "All G",
          x < ALL(df.y) AS "All L",
          x <= ALL(df.y) AS "All Leq",
          x >= ANY(df.y) AS "Any Geq",
          x > ANY(df.y) AS "Any G",
          x < ANY(df.y) AS "Any L",
          x <= ANY(df.y) AS "Any Leq",
          x == ANY(df.y) AS "Any eq",
          x != ANY(df.y) AS "Any Neq",
        FROM df
        """,
    ).collect()

    assert res.to_dict(as_series=False) == {
        "All Geq": [0, 0, 0, 0, 1, 1],
        "All G": [0, 0, 0, 0, 0, 1],
        "All L": [1, 0, 0, 0, 0, 0],
        "All Leq": [1, 1, 0, 0, 0, 0],
        "Any Geq": [0, 1, 1, 1, 1, 1],
        "Any G": [0, 0, 1, 1, 1, 1],
        "Any L": [1, 1, 1, 1, 0, 0],
        "Any Leq": [1, 1, 1, 1, 1, 0],
        "Any eq": [0, 1, 1, 1, 1, 0],
        "Any Neq": [1, 0, 0, 0, 0, 1],
    }


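# A hedged editorial aside (not part of the upstream suite): the quantified
# comparisons above reduce to min/max aggregates; `x >= ALL(y)` behaves like
# `x >= MAX(y)` and `x >= ANY(y)` like `x >= MIN(y)`. A minimal native sketch:
def _demo_any_all_as_min_max() -> None:
    df = pl.DataFrame({"x": [-1, 0, 1, 2, 3, 4], "y": [1, 0, 0, 1, 2, 3]})
    out = df.select(
        (pl.col("x") >= pl.col("y").max()).alias("All Geq"),  # x >= ALL(y)
        (pl.col("x") >= pl.col("y").min()).alias("Any Geq"),  # x >= ANY(y)
    )
    assert out["All Geq"].to_list() == [False, False, False, False, True, True]
    assert out["Any Geq"].to_list() == [False, True, True, True, True, True]

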
@pytest.mark.parametrize(
    ("data", "schema"),
    [
        ({"x": [1, 2, 3, 4]}, None),
        ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}),
        ({"x": ["aa", "bb"]}, {"x": pl.String}),
        ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}),
    ],
)
def test_boolean_where_clauses(
    data: dict[str, Any], schema: dict[str, DataType] | None
) -> None:
    df = pl.DataFrame(data=data, schema=schema)
    empty_df = df.clear()

    for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"):
        assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}"))

    for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"):
        assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}"))


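# A hedged editorial aside (not part of the upstream suite): a constant
# predicate folds to keep-all or keep-none, matching `filter` on a boolean
# literal in native polars:
def _demo_constant_predicate_filter() -> None:
    df = pl.DataFrame({"x": [1, 2, 3]})
    assert_frame_equal(df.filter(pl.lit(True)), df)  # keep-all
    assert df.filter(pl.lit(False)).height == 0  # keep-none

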
def test_count() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 1, 22, 22, 333],
            "c": [1, 1, None, None, 2],
        }
    )
    res = df.sql(
        """
        SELECT
          -- count
          COUNT(a) AS count_a,
          COUNT(b) AS count_b,
          COUNT(c) AS count_c,
          COUNT(*) AS count_star,
          COUNT(NULL) AS count_null,
          -- count distinct
          COUNT(DISTINCT a) AS count_unique_a,
          COUNT(DISTINCT b) AS count_unique_b,
          COUNT(DISTINCT c) AS count_unique_c,
          COUNT(DISTINCT NULL) AS count_unique_null,
        FROM self
        """,
    )
    assert res.to_dict(as_series=False) == {
        "count_a": [5],
        "count_b": [5],
        "count_c": [3],
        "count_star": [5],
        "count_null": [0],
        "count_unique_a": [5],
        "count_unique_b": [3],
        "count_unique_c": [2],
        "count_unique_null": [0],
    }

    df = pl.DataFrame({"x": [None, None, None]})
    res = df.sql(
        """
        SELECT
          COUNT(x) AS count_x,
          COUNT(*) AS count_star,
          COUNT(DISTINCT x) AS count_unique_x
        FROM self
        """
    )
    assert res.to_dict(as_series=False) == {
        "count_x": [0],
        "count_star": [3],
        "count_unique_x": [0],
    }


def test_distinct() -> None:
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 2, 3],
            "b": [1, 2, 3, 4, 5, 6],
        }
    )
    ctx = pl.SQLContext(register_globals=True, eager=True)
    res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC")
    assert_frame_equal(
        left=df.select("a").unique().sort(by="a", descending=True),
        right=res1,
    )

    res2 = ctx.execute(
        """
        SELECT DISTINCT
          a * 2 AS two_a,
          b / 2 AS half_b
        FROM df
        ORDER BY two_a ASC, half_b DESC
        """,
    )
    assert res2.to_dict(as_series=False) == {
        "two_a": [2, 2, 4, 6],
        "half_b": [1, 0, 2, 3],
    }

    # test unregistration
    ctx.unregister("df")
    with pytest.raises(SQLInterfaceError, match="relation 'df' was not found"):
        ctx.execute("SELECT * FROM df")


def test_frame_sql_globals_error() -> None:
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df2 = pl.DataFrame({"a": [2, 3, 4], "b": [7, 6, 5]})  # noqa: F841

    query = """
        SELECT df1.a, df2.b
        FROM df2 JOIN df1 ON df1.a = df2.a
        ORDER BY b DESC
    """
    with pytest.raises(SQLInterfaceError, match="relation.*not found.*"):
        df1.sql(query=query)

    res = pl.sql(query=query, eager=True)
    assert res.to_dict(as_series=False) == {"a": [2, 3], "b": [7, 6]}


def test_in_no_ops_11946() -> None:
    lf = pl.LazyFrame(
        [
            {"i1": 1},
            {"i1": 2},
            {"i1": 3},
        ]
    )
    out = lf.sql(
        query="SELECT * FROM frame_data WHERE i1 in (1, 3)",
        table_name="frame_data",
    ).collect()
    assert out.to_dict(as_series=False) == {"i1": [1, 3]}


def test_limit_offset() -> None:
    n_values = 11
    lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))})
    ctx = pl.SQLContext(tbl=lf)

    assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [
        (4, 6),
        (5, 5),
        (6, 4),
    ]
    for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]:
        out = ctx.execute(
            f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True
        )
        assert_frame_equal(out, lf.slice(offset, limit).collect())
        assert len(out) == min(limit, n_values - offset)


def test_register_context() -> None:
    # used as a context manager, tables created within each scope are
    # unregistered on exit from that scope; arbitrary levels of nesting
    # are supported.
    with pl.SQLContext() as ctx:
        _lf1 = pl.LazyFrame({"a": [1, 2, 3], "b": ["m", "n", "o"]})
        _lf2 = pl.LazyFrame({"a": [2, 3, 4], "c": ["p", "q", "r"]})
        ctx.register_globals()
        assert ctx.tables() == ["_lf1", "_lf2"]

        with ctx:
            _lf3 = pl.LazyFrame({"a": [3, 4, 5], "b": ["s", "t", "u"]})
            _lf4 = pl.LazyFrame({"a": [4, 5, 6], "c": ["v", "w", "x"]})
            ctx.register_globals(n=2)
            assert ctx.tables() == ["_lf1", "_lf2", "_lf3", "_lf4"]

        assert ctx.tables() == ["_lf1", "_lf2"]

    assert ctx.tables() == []


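# A hedged editorial aside (not part of the upstream suite): the same
# bookkeeping can be done explicitly with `register`/`unregister` instead
# of relying on `register_globals`:
def _demo_explicit_registration() -> None:
    ctx = pl.SQLContext()
    ctx.register("lf", pl.LazyFrame({"a": [1, 2, 3]}))
    assert ctx.tables() == ["lf"]
    ctx.unregister("lf")
    assert ctx.tables() == []

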
def test_sql_on_compatible_frame_types() -> None:
    df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # create various different frame types
    dfp = df.to_pandas()
    dfa = df.to_arrow()
    dfb = dfa.to_batches()[0]  # noqa: F841
    dfo = PyCapsuleStreamHolder(df)  # noqa: F841

    # run polars sql query against all frame types
    for dfs in (  # noqa: B007
        (df["a"] * 2).rename("c"),  # polars series
        (dfp["a"] * 2).rename("c"),  # pandas series
    ):
        res = pl.sql(
            """
            SELECT a, b, SUM(c) AS cc FROM (
              SELECT * FROM df             -- polars frame
              UNION ALL SELECT * FROM dfp  -- pandas frame
              UNION ALL SELECT * FROM dfa  -- pyarrow table
              UNION ALL SELECT * FROM dfb  -- pyarrow record batch
              UNION ALL SELECT * FROM dfo  -- arbitrary pycapsule object
            ) tbl
            INNER JOIN dfs ON dfs.c == tbl.b  -- join on pandas/polars series
            GROUP BY "a", "b"
            ORDER BY "a", "b"
            """
        ).collect()

        expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [20, 30]})
        assert_frame_equal(left=expected, right=res)

    # register and operate on non-polars frames
    for obj in (dfa, dfp):
        with pl.SQLContext(obj=obj) as ctx:
            res = ctx.execute("SELECT * FROM obj", eager=True)
            assert_frame_equal(df, res)

    # `register_globals` does not register every SQL-compatible object
    # (e.g. pandas frames), so 'dfp' should not be found here
    with pytest.raises(SQLInterfaceError, match="relation 'dfp' was not found"):
        pl.SQLContext(register_globals=True).execute("SELECT * FROM dfp")


def test_nested_cte_column_aliasing() -> None:
    # trace through nested CTEs with multiple levels of column & table aliasing
    df = pl.sql(
        """
        WITH
          x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)),
          y (m, n) AS (
            WITH z(c, d) AS (SELECT a, b FROM x)
            SELECT d*2 AS d2, c*3 AS c3 FROM z
          )
        SELECT n, m FROM y
        """,
        eager=True,
    )
    # evaluation trace: x(a, b) = {(1,2), (3,4)}; z(c, d) mirrors x; the inner
    # select yields (d2, c3) = {(4,3), (8,9)}, positionally renamed to y(m, n)
    assert df.to_dict(as_series=False) == {
        "n": [3, 9],
        "m": [4, 8],
    }


303
def test_invalid_derived_table_column_aliases() -> None:
304
values_query = "SELECT * FROM (VALUES (1,2), (3,4))"
305
306
with pytest.raises(
307
SQLSyntaxError,
308
match=r"columns \(5\) in alias 'tbl' does not match .* the table/query \(2\)",
309
):
310
pl.sql(f"{values_query} AS tbl(a, b, c, d, e)")
311
312
assert pl.sql(f"{values_query} tbl", eager=True).rows() == [(1, 2), (3, 4)]
313
314
315
def test_values_clause_table_registration() -> None:
316
with pl.SQLContext(frames=None, eager=True) as ctx:
317
# initially no tables are registered
318
assert ctx.tables() == []
319
320
# confirm that VALUES clause derived table is registered, post-query
321
res1 = ctx.execute("SELECT * FROM (VALUES (-1,1)) AS tbl(x, y)")
322
assert ctx.tables() == ["tbl"]
323
324
# and confirm that we can select from it by the registered name
325
res2 = ctx.execute("SELECT x, y FROM tbl")
326
for res in (res1, res2):
327
assert res.to_dict(as_series=False) == {"x": [-1], "y": [1]}
328
329
330
def test_read_csv(tmp_path: Path) -> None:
331
# check empty string vs null, parsing of dates, etc
332
df = pl.DataFrame(
333
{
334
"label": ["lorem", None, "", "ipsum"],
335
"num": [-1, None, 0, 1],
336
"dt": [
337
date(1969, 7, 5),
338
date(1999, 12, 31),
339
date(2077, 10, 10),
340
None,
341
],
342
}
343
)
344
csv_target = tmp_path / "test_sql_read.csv"
345
df.write_csv(csv_target)
346
347
res = pl.sql(f"SELECT * FROM read_csv('{csv_target}')").collect()
348
assert_frame_equal(df, res)
349
350
with pytest.raises(
351
SQLSyntaxError,
352
match="`read_csv` expects a single file path; found 3 arguments",
353
):
354
pl.sql("SELECT * FROM read_csv('a','b','c')")
355
356
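# A hedged editorial aside (not part of the upstream suite): the `read_csv`
# table function behaves like a scan of the file; a native sketch, assuming
# date parsing is enabled (as the date round-trip above implies):
def _demo_scan_csv_equivalent(tmp_path: Path) -> None:
    target = tmp_path / "demo.csv"
    pl.DataFrame({"dt": [date(2020, 1, 1)]}).write_csv(target)
    out = pl.scan_csv(target, try_parse_dates=True).collect()
    assert out.schema["dt"] == pl.Date

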
def test_global_variable_inference_17398() -> None:
    users = pl.DataFrame({"id": "1"})

    res = pl.sql(
        query="""
        WITH user_by_email AS (SELECT id FROM users)
        SELECT * FROM user_by_email
        """,
        eager=True,
    )
    assert_frame_equal(res, users)


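# For reference (editorial note, not from the upstream suite): `pl.sql`
# resolves unregistered relation names such as `users` from the surrounding
# Python scope, including inside CTEs; the inference bug exercised here is
# issue #17398, per the test name.

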
@pytest.mark.parametrize(
    "query",
    [
        "SELECT invalid_column FROM self",
        "SELECT key, invalid_column FROM self",
        "SELECT invalid_column * 2 FROM self",
        "SELECT * FROM self ORDER BY invalid_column",
        "SELECT * FROM self WHERE invalid_column = 200",
        "SELECT * FROM self WHERE invalid_column = '200'",
        "SELECT key, SUM(n) AS sum_n FROM self GROUP BY invalid_column",
    ],
)
def test_invalid_cols(query: str) -> None:
    df = pl.DataFrame(
        {
            "key": ["xx", "xx", "yy"],
            "n": ["100", "200", "300"],
        }
    )
    with pytest.raises(ColumnNotFoundError, match="invalid_column"):
        df.sql(query)


@pytest.mark.parametrize("filter_expr", ["", "WHERE 1 = 1", "WHERE a == 1 OR a != 1"])
@pytest.mark.parametrize("order_expr", ["", "ORDER BY 1", "ORDER BY a"])
def test_select_output_heights_20058_21084(filter_expr: str, order_expr: str) -> None:
    df = pl.DataFrame({"a": [1, 2, 3]})

    # Queries that maintain original height

    assert_frame_equal(
        df.sql(f"SELECT 1 as a FROM self {filter_expr} {order_expr}").cast(pl.Int64),
        pl.select(a=pl.Series([1, 1, 1])),
    )

    assert_frame_equal(
        df.sql(f"SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": [2, 2, 2], "b": [1, 1, 1]}),
    )

    # Queries that aggregate to unit height

    assert_frame_equal(
        df.sql(f"SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": 3}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 3, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 1, "b": 1}),
    )

    assert_frame_equal(
        df.sql(f"SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
            pl.Int64
        ),
        pl.DataFrame({"a": 6, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 1, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 2, "b": 1}),
    )

    assert_frame_equal(
        df.sql(
            f"SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}"
        ).cast(pl.Int64),
        pl.DataFrame({"a": 2, "b": 1}),
    )


def test_select_explode_height_filter_order_by() -> None:
    # Note: `unnest()` from SQL equates to `Expr.explode()`
    df = pl.DataFrame(
        {
            "list_long": [[1, 2, 3], [4, 5, 6]],
            "sort_key": [2, 1],
            "filter_mask": [False, True],
            "filter_mask_all_true": True,
        }
    )

    # When the height of the unnest output exceeds the height of `sort_key`,
    # the sort key is extended with NULLs (not broadcast per list element).

    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key"),
        pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(),
    )

    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST"
        ),
        pl.Series("list", [3, 4, 5, 6, 2, 1]).to_frame(),
    )

    # Literals are broadcast to the output height of UNNEST:
    assert_frame_equal(
        df.sql("SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key"),
        pl.select(pl.Series("list", [2, 1, 3, 4, 5, 6]), x=1),
    )

    # Note: the filter applies before projections in SQL
    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key"
        ),
        pl.Series("list", [4, 5, 6]).to_frame(),
    )

    assert_frame_equal(
        df.sql(
            "SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key"
        ),
        pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(),
    )


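# A hedged editorial aside (not part of the upstream suite): per the note in
# the test above, SQL UNNEST maps to `Expr.explode()` in native polars:
def _demo_unnest_is_explode() -> None:
    df = pl.DataFrame({"list_long": [[1, 2, 3], [4, 5, 6]]})
    out = df.select(pl.col("list_long").explode().alias("list"))
    assert out["list"].to_list() == [1, 2, 3, 4, 5, 6]

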
@pytest.mark.parametrize(
    ("query", "result"),
    [
        (
            """SELECT a, COUNT(*) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT() OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 3, 3, 3],
        ),
        (
            """SELECT a, COUNT(i) OVER (PARTITION BY a) AS b FROM self""",
            [3, 3, 3, 1, 1, 1, 1],
        ),
        (
            """SELECT a, COUNT(DISTINCT i) OVER (PARTITION BY a) AS b FROM self""",
            [2, 2, 2, 1, 1, 1, 1],
        ),
    ],
)
def test_count_partition_22665(query: str, result: list[Any]) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 3, 3, 3],
            "i": [0, 0, 1, 2, 3, None, None],
        }
    )
    out = df.sql(query).select("b")
    expected = pl.DataFrame({"b": result}).cast({"b": pl.UInt32})
    assert_frame_equal(out, expected)


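# For reference (editorial note, not from the upstream suite): COUNT(*) and
# COUNT() count all rows per partition, while COUNT(i) and COUNT(DISTINCT i)
# skip NULLs; e.g. partition a=3 has three rows but i = [3, NULL, NULL],
# hence counts of 1 in the last two parametrized cases.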