Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_literals.py
8406 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, timedelta
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
9
from polars.testing import assert_frame_equal
10
from tests.unit.sql import assert_sql_matches
11
12
13
def test_bit_hex_literals() -> None:
    """Check that bit-string and hex-string literals are returned as bytes."""
    query = """
        SELECT *,
          -- bit strings
          b'' AS b0,
          b'1001' AS b1,
          b'11101011' AS b2,
          b'1111110100110010' AS b3,
          -- hex strings
          x'' AS x0,
          x'FF' AS x1,
          x'4142' AS x2,
          x'DeadBeef' AS x3,
        FROM df
    """
    expected = {
        "b0": [b""],
        "b1": [b"\t"],
        "b2": [b"\xeb"],
        "b3": [b"\xfd2"],
        "x0": [b""],
        "x1": [b"\xff"],
        "x2": [b"AB"],
        "x3": [b"\xde\xad\xbe\xef"],
    }
    with pl.SQLContext(df=None, eager=True) as ctx:
        result = ctx.execute(query)
        assert result.to_dict(as_series=False) == expected
42
43
44
def test_bit_hex_filter() -> None:
    """Check that a binary column can be filtered against bit/hex/string literals."""
    frame = pl.DataFrame(
        {"bin": [b"\x01", b"\x02", b"\x03", b"\x04"], "val": [9, 8, 7, 6]}
    )
    # every spelling is expected to select the same rows (those with bin > 0x02)
    literals_for_two = ("b'10'", "x'02'", "'\x02'", "b'0010'")

    with pl.SQLContext(test=frame) as ctx:
        for literal in literals_for_two:
            query = f"SELECT val FROM test WHERE bin > {literal}"
            result = ctx.execute(query, eager=True)
            assert result.to_series().to_list() == [7, 6]
52
53
54
def test_bit_hex_errors() -> None:
    """Check that malformed bit/hex literals and unsupported literals raise."""
    only_binary_digits = "bit string literal should contain only 0s and 1s"
    even_hex_digits = "hex string literal must have an even number of digits"
    unsupported_literal = r'NationalStringLiteral\("hmmm"\) is not a supported literal'

    with pl.SQLContext(test=None) as ctx:
        # errors surfaced through a SQL context
        with pytest.raises(SQLSyntaxError, match=only_binary_digits):
            ctx.execute("SELECT b'007' FROM test", eager=True)

        with pytest.raises(SQLSyntaxError, match=even_hex_digits):
            ctx.execute("SELECT x'00F' FROM test", eager=True)

        # errors surfaced through standalone expression parsing
        with pytest.raises(SQLSyntaxError, match=even_hex_digits):
            pl.sql_expr("colx IN (x'FF',x'123')")

        with pytest.raises(SQLInterfaceError, match=unsupported_literal):
            pl.sql_expr("N'hmmm'")
79
80
81
def test_bit_hex_membership() -> None:
    """Check IN-list membership against bit-string and hex-string literals."""
    frame = pl.DataFrame(
        {
            "x": [b"\x05", b"\xff", b"\xcc", b"\x0b"],
            "y": [1, 2, 3, 4],
        }
    )
    # this checks the internal `visit_any_value` codepath
    in_lists = (
        "b'0101', b'1011'",
        "x'05', x'0b'",
    )
    for in_list in in_lists:
        predicate = pl.sql_expr(f"x IN ({in_list})")
        matched = frame.filter(predicate)
        assert matched["y"].to_list() == [1, 4]
95
96
97
def test_dollar_quoted_literals() -> None:
    """Check that dollar-quoted strings (tagged or untagged) parse as strings."""
    quoted = pl.sql(
        """
        SELECT
          $$xyz$$ AS dq1,
          $q$xyz$q$ AS dq2,
          $tag$xyz$tag$ AS dq3,
          $QUOTE$xyz$QUOTE$ AS dq4,
        """
    ).collect()
    expected = {f"dq{n}": ["xyz"] for n in range(1, 5)}
    assert quoted.to_dict(as_series=False) == expected

    # a single '$' inside the body is part of the value
    embedded = pl.sql("SELECT $$x$z$$ AS dq").collect()
    assert embedded.item() == "x$z"
111
112
113
def test_fixed_intervals() -> None:
    """Check fixed-duration INTERVAL parsing, plus the forms that must raise."""
    week_and_change = timedelta(weeks=1, hours=2, minutes=3, seconds=4)
    expected = pl.DataFrame(
        {
            "i1": [week_and_change],
            "i2": [timedelta(microseconds=100100)],
            "i3": [week_and_change],
        },
    ).cast(pl.Duration("ns"))

    # (exception type, message pattern, query) for each rejected interval
    rejected = (
        # TODO: negative intervals
        (
            SQLInterfaceError,
            "minus signs are not yet supported in interval strings; found '-7d'",
            "SELECT INTERVAL '-7d' AS one_week_ago FROM df",
        ),
        (
            SQLSyntaxError,
            "unary ops are not valid on interval strings; found -'7d'",
            "SELECT INTERVAL -'7d' AS one_week_ago FROM df",
        ),
        (
            SQLSyntaxError,
            "fixed-duration interval cannot contain years, quarters, or months",
            "SELECT INTERVAL '1 quarter 1 month' AS q FROM df",
        ),
    )

    with pl.SQLContext(df=None, eager=True) as ctx:
        actual = ctx.execute(
            """
            SELECT
              -- short form with/without spaces
              INTERVAL '1w2h3m4s' AS i1,
              INTERVAL '100ms 100us' AS i2,
              -- long form with/without commas (case-insensitive)
              INTERVAL '1 week, 2 hours, 3 minutes, 4 seconds' AS i3
            FROM df
            """
        )
        assert_frame_equal(expected, actual)

        for error_type, message, query in rejected:
            with pytest.raises(error_type, match=message):
                ctx.execute(query)
154
155
156
def test_interval_offsets() -> None:
    """Check adding/subtracting calendar INTERVALs on datetime and date columns."""
    datetimes = [
        datetime(1899, 12, 31, 8),
        datetime(1999, 6, 8, 10, 30),
        datetime(2010, 5, 7, 20, 20, 20),
    ]
    dates = [
        date(1950, 4, 10),
        date(2048, 1, 20),
        date(2026, 8, 5),
    ]
    frame = pl.DataFrame({"dtm": datetimes, "dt": dates})

    shifted = frame.sql(
        """
        SELECT
          dtm + INTERVAL '2 months, 30 minutes' AS dtm_plus_2mo30m,
          dt + INTERVAL '100 years' AS dt_plus_100y,
          dt - INTERVAL '1 quarter' AS dt_minus_1q
        FROM self
        ORDER BY 1
        """
    )

    expected = {
        "dtm_plus_2mo30m": [
            datetime(1900, 2, 28, 8, 30),
            datetime(1999, 8, 8, 11, 0),
            datetime(2010, 7, 7, 20, 50, 20),
        ],
        "dt_plus_100y": [
            date(2050, 4, 10),
            date(2148, 1, 20),
            date(2126, 8, 5),
        ],
        "dt_minus_1q": [
            date(1950, 1, 10),
            date(2047, 10, 20),
            date(2026, 5, 5),
        ],
    }
    assert shifted.to_dict(as_series=False) == expected
199
200
201
@pytest.mark.parametrize(
    ("interval_comparison", "expected_result"),
    [
        ("INTERVAL '3 days' <= INTERVAL '3 days, 1 microsecond'", True),
        ("INTERVAL '3 days, 1 microsecond' <= INTERVAL '3 days'", False),
        ("INTERVAL '3 months' >= INTERVAL '3 months'", True),
        ("INTERVAL '2 quarters' < INTERVAL '2 quarters'", False),
        ("INTERVAL '2 quarters' > INTERVAL '2 quarters'", False),
        ("INTERVAL '3 years' <=> INTERVAL '3 years'", True),
        ("INTERVAL '3 years' == INTERVAL '1008 weeks'", False),
        ("INTERVAL '8 weeks' != INTERVAL '2 months'", True),
        ("INTERVAL '8 weeks' = INTERVAL '2 months'", False),
        ("INTERVAL '1 year' != INTERVAL '365 days'", True),
        ("INTERVAL '1 year' = INTERVAL '1 year'", True),
    ],
)
def test_interval_comparisons(interval_comparison: str, expected_result: bool) -> None:
    """Check that comparing two INTERVAL literals yields the expected boolean."""
    query = f"SELECT {interval_comparison} AS res"
    expected = {"res": [expected_result]}
    with pl.SQLContext() as ctx:
        lazy_result = ctx.execute(query)
        assert lazy_result.collect().to_dict(as_series=False) == expected
221
222
223
def test_select_literals_no_table() -> None:
    """Check that literals can be selected without any FROM clause."""
    selected = pl.sql("SELECT 1 AS one, '2' AS two, 3.0 AS three", eager=True)
    expected = {"one": [1], "two": ["2"], "three": [3.0]}
    assert selected.to_dict(as_series=False) == expected
230
231
232
def test_literal_only_select() -> None:
    """Check that literal-only SELECT produces one row per source table row."""
    df = pl.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]})

    # (query, expected) pairs evaluated against the three-row table
    broadcast_cases = (
        (
            "SELECT 1 AS one, 2.5 AS two FROM self",
            {"one": [1, 1, 1], "two": [2.5, 2.5, 2.5]},
        ),
        (
            "SELECT 1 + 2 AS sum, 'abc' || 'def' AS concat FROM self",
            {"sum": [3, 3, 3], "concat": ["abcdef", "abcdef", "abcdef"]},
        ),
    )
    for query, expected in broadcast_cases:
        assert_sql_matches(
            df,
            query=query,
            expected=expected,
            compare_with="sqlite",
        )

    # empty table should result in zero rows
    df = df.clear()
    assert_sql_matches(
        df,
        query="SELECT 42 AS the_answer, 'test' AS str FROM self",
        expected={"the_answer": [], "str": []},
        compare_with="sqlite",
    )
258
259
260
def test_literal_only_select_distinct() -> None:
    """Check literal-only SELECT combined with DISTINCT."""
    df = pl.DataFrame({"x": [1, 2, 3, 4, 5]})

    # DISTINCT on broadcast literals should collapse to 1 row
    query = "SELECT DISTINCT 42 AS val FROM self"
    expected = {"val": [42]}
    assert_sql_matches(df, query=query, expected=expected, compare_with="sqlite")
271
272
273
def test_literal_only_select_order_by() -> None:
    """Check literal-only SELECT combined with ORDER BY (must not error)."""
    df = pl.DataFrame({"x": [3, 1, 2]})

    # ORDER BY on literal column is a no-op but should still work
    query = "SELECT 1 AS one FROM self ORDER BY one"
    expected = {"one": [1, 1, 1]}
    assert_sql_matches(df, query=query, expected=expected, compare_with="sqlite")
284
285
286
def test_literal_only_select_where() -> None:
    """Check that literal-only SELECT honours WHERE filtering."""
    df = pl.DataFrame({"x": [1, 2, 3, 4, 5]})

    # WHERE clause should filter, then literals broadcast to the filtered height
    filter_cases = (
        # two rows satisfy the predicate
        ("SELECT 99 AS lit FROM self WHERE x > 3", {"lit": [99, 99]}),
        # no row satisfies the predicate
        ("SELECT 99 AS lit FROM self WHERE x > 100000000", {"lit": []}),
    )
    for query, expected in filter_cases:
        assert_sql_matches(
            df,
            query=query,
            expected=expected,
            compare_with="sqlite",
        )
303
304
305
def test_literal_only_select_limit() -> None:
    """Check literal-only SELECT combined with LIMIT."""
    df = pl.DataFrame({"x": list(range(10))})

    query = "SELECT 'val' AS s FROM self LIMIT 3"
    expected = {"s": ["val", "val", "val"]}
    assert_sql_matches(df, query=query, expected=expected, compare_with="sqlite")
315
316
317
def test_literal_only_select_nested_expressions() -> None:
    """Check literal-only SELECT of nested expressions with no column refs."""
    df = pl.DataFrame({"x": [1, 2]})

    query = """
        SELECT
          CASE WHEN 1 > 0 THEN 'yes' ELSE 'no' END AS cond,
          COALESCE(NULL, 'fallback') AS coal,
          ABS(-5) AS absval
        FROM self
    """
    expected = {
        "cond": ["yes", "yes"],
        "coal": ["fallback", "fallback"],
        "absval": [5, 5],
    }
    assert_sql_matches(df, query=query, expected=expected, compare_with="sqlite")
337
338
339
def test_mixed_literal_and_column() -> None:
    """Check a SELECT that mixes a column reference with a literal."""
    df = pl.DataFrame({"x": [10, 20, 30]})

    # When there's at least one column reference, normal behavior applies
    query = "SELECT x, 99 AS lit FROM self"
    expected = {"x": [10, 20, 30], "lit": [99, 99, 99]}
    assert_sql_matches(df, query=query, expected=expected, compare_with="sqlite")
350
351
352
def test_select_from_table_with_reserved_names() -> None:
    """Check that quoted reserved words work as table and column names."""
    # NOTE(review): this local must stay named `select`; the query's
    # FROM "select" presumably resolves the frame by variable name — confirm
    # against the `pl.sql` documentation before renaming it.
    select = pl.DataFrame({"select": [1, 2, 3], "from": [4, 5, 6]})
    query = """
        SELECT "from", "select"
        FROM "select"
        WHERE "from" >= 5 AND "select" % 2 != 1
    """
    matching_rows = pl.sql(query=query, eager=True).rows()
    assert matching_rows == [(5, 2)]
363
364