Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_literals.py
6939 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, timedelta
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
9
from polars.testing import assert_frame_equal
10
11
12
def test_bit_hex_literals() -> None:
    """Bit (b'..') and hex (x'..') SQL literals evaluate to the expected bytes."""
    query = """
        SELECT *,
          -- bit strings
          b'' AS b0,
          b'1001' AS b1,
          b'11101011' AS b2,
          b'1111110100110010' AS b3,
          -- hex strings
          x'' AS x0,
          x'FF' AS x1,
          x'4142' AS x2,
          x'DeadBeef' AS x3,
        FROM df
    """
    with pl.SQLContext(df=None, eager=True) as ctx:
        result = ctx.execute(query)

    # one single-row column per aliased literal, in SELECT order
    expected = {
        "b0": [b""],
        "b1": [b"\t"],
        "b2": [b"\xeb"],
        "b3": [b"\xfd2"],
        "x0": [b""],
        "x1": [b"\xff"],
        "x2": [b"AB"],
        "x3": [b"\xde\xad\xbe\xef"],
    }
    assert result.to_dict(as_series=False) == expected
41
42
43
def test_bit_hex_filter() -> None:
    """Different spellings of the byte 0x02 filter a binary column identically."""
    frame = pl.DataFrame(
        {"bin": [b"\x01", b"\x02", b"\x03", b"\x04"], "val": [9, 8, 7, 6]}
    )
    # bit string, hex string, quoted string holding the raw character, and a
    # zero-padded bit string
    spellings_of_two = ("b'10'", "x'02'", "'\x02'", "b'0010'")

    with pl.SQLContext(test=frame) as ctx:
        for two in spellings_of_two:
            query = f"SELECT val FROM test WHERE bin > {two}"
            selected = ctx.execute(query, eager=True).to_series().to_list()
            assert selected == [7, 6]
51
52
53
def test_bit_hex_errors() -> None:
    """Malformed bit/hex literals and an unsupported literal kind raise errors."""
    odd_hex_digits = "hex string literal must have an even number of digits"

    with pl.SQLContext(test=None) as ctx:
        # a bit string containing a digit other than 0 or 1
        with pytest.raises(
            SQLSyntaxError,
            match="bit string literal should contain only 0s and 1s",
        ):
            ctx.execute("SELECT b'007' FROM test", eager=True)

        # a hex string with an odd digit count, through the SQL context...
        with pytest.raises(SQLSyntaxError, match=odd_hex_digits):
            ctx.execute("SELECT x'00F' FROM test", eager=True)

        # ...and through a standalone SQL expression
        with pytest.raises(SQLSyntaxError, match=odd_hex_digits):
            pl.sql_expr("colx IN (x'FF',x'123')")

        # national string literals (N'...') are rejected
        with pytest.raises(
            SQLInterfaceError,
            match=r'NationalStringLiteral\("hmmm"\) is not a supported literal',
        ):
            pl.sql_expr("N'hmmm'")
78
79
80
def test_bit_hex_membership() -> None:
    """Bit and hex literals can be used as members of an ``IN (...)`` list."""
    frame = pl.DataFrame(
        {
            "x": [b"\x05", b"\xff", b"\xcc", b"\x0b"],
            "y": [1, 2, 3, 4],
        }
    )
    # this checks the internal `visit_any_value` codepath
    member_lists = (
        "b'0101', b'1011'",
        "x'05', x'0b'",
    )
    for values in member_lists:
        predicate = pl.sql_expr(f"x IN ({values})")
        matched = frame.filter(predicate)
        assert matched["y"].to_list() == [1, 4]
94
95
96
def test_dollar_quoted_literals() -> None:
    """Dollar-quoted strings, with or without a tag, evaluate to their body."""
    tagged = pl.sql(
        """
        SELECT
          $$xyz$$ AS dq1,
          $q$xyz$q$ AS dq2,
          $tag$xyz$tag$ AS dq3,
          $QUOTE$xyz$QUOTE$ AS dq4,
        """
    ).collect()
    expected = {f"dq{n}": ["xyz"] for n in range(1, 5)}
    assert tagged.to_dict(as_series=False) == expected

    # a single `$` inside the body is kept as part of the string
    embedded = pl.sql("SELECT $$x$z$$ AS dq").collect()
    assert embedded.item() == "x$z"
110
111
112
def test_fixed_intervals() -> None:
    """Fixed INTERVAL strings parse to durations; unsupported forms raise."""
    query = """
        SELECT
          -- short form with/without spaces
          INTERVAL '1w2h3m4s' AS i1,
          INTERVAL '100ms 100us' AS i2,
          -- long form with/without commas (case-insensitive)
          INTERVAL '1 week, 2 hours, 3 minutes, 4 seconds' AS i3
        FROM df
    """
    with pl.SQLContext(df=None, eager=True) as ctx:
        result = ctx.execute(query)

        # i1 and i3 spell the same span in short and long form
        week_and_change = timedelta(weeks=1, hours=2, minutes=3, seconds=4)
        expected = pl.DataFrame(
            {
                "i1": [week_and_change],
                "i2": [timedelta(microseconds=100100)],
                "i3": [week_and_change],
            },
        ).cast(pl.Duration("ns"))

        assert_frame_equal(expected, result)

        # TODO: negative intervals
        with pytest.raises(
            SQLInterfaceError,
            match="minus signs are not yet supported in interval strings; found '-7d'",
        ):
            ctx.execute("SELECT INTERVAL '-7d' AS one_week_ago FROM df")

        with pytest.raises(
            SQLSyntaxError,
            match="unary ops are not valid on interval strings; found -'7d'",
        ):
            ctx.execute("SELECT INTERVAL -'7d' AS one_week_ago FROM df")

        with pytest.raises(
            SQLSyntaxError,
            match="fixed-duration interval cannot contain years, quarters, or months",
        ):
            ctx.execute("SELECT INTERVAL '1 quarter 1 month' AS q FROM df")
153
154
155
def test_interval_offsets() -> None:
    """Adding/subtracting INTERVALs shifts datetime and date columns."""
    timestamps = [
        datetime(1899, 12, 31, 8),
        datetime(1999, 6, 8, 10, 30),
        datetime(2010, 5, 7, 20, 20, 20),
    ]
    days = [
        date(1950, 4, 10),
        date(2048, 1, 20),
        date(2026, 8, 5),
    ]
    frame = pl.DataFrame({"dtm": timestamps, "dt": days})

    shifted = frame.sql(
        """
        SELECT
          dtm + INTERVAL '2 months, 30 minutes' AS dtm_plus_2mo30m,
          dt + INTERVAL '100 years' AS dt_plus_100y,
          dt - INTERVAL '1 quarter' AS dt_minus_1q
        FROM self
        ORDER BY 1
        """
    )

    expected = {
        "dtm_plus_2mo30m": [
            datetime(1900, 2, 28, 8, 30),
            datetime(1999, 8, 8, 11, 0),
            datetime(2010, 7, 7, 20, 50, 20),
        ],
        "dt_plus_100y": [
            date(2050, 4, 10),
            date(2148, 1, 20),
            date(2126, 8, 5),
        ],
        "dt_minus_1q": [
            date(1950, 1, 10),
            date(2047, 10, 20),
            date(2026, 5, 5),
        ],
    }
    assert shifted.to_dict(as_series=False) == expected
198
199
200
@pytest.mark.parametrize(
    ("interval_comparison", "expected_result"),
    [
        ("INTERVAL '3 days' <= INTERVAL '3 days, 1 microsecond'", True),
        ("INTERVAL '3 days, 1 microsecond' <= INTERVAL '3 days'", False),
        ("INTERVAL '3 months' >= INTERVAL '3 months'", True),
        ("INTERVAL '2 quarters' < INTERVAL '2 quarters'", False),
        ("INTERVAL '2 quarters' > INTERVAL '2 quarters'", False),
        ("INTERVAL '3 years' <=> INTERVAL '3 years'", True),
        ("INTERVAL '3 years' == INTERVAL '1008 weeks'", False),
        ("INTERVAL '8 weeks' != INTERVAL '2 months'", True),
        ("INTERVAL '8 weeks' = INTERVAL '2 months'", False),
        ("INTERVAL '1 year' != INTERVAL '365 days'", True),
        ("INTERVAL '1 year' = INTERVAL '1 year'", True),
    ],
)
def test_interval_comparisons(interval_comparison: str, expected_result: bool) -> None:
    """Each comparison between two INTERVAL literals yields the expected boolean."""
    with pl.SQLContext() as ctx:
        query_result = ctx.execute(f"SELECT {interval_comparison} AS res")
        observed = query_result.collect().to_dict(as_series=False)
        assert observed == {"res": [expected_result]}
220
221
222
def test_select_literals_no_table() -> None:
    """A SELECT of bare literals without a FROM clause returns a single row."""
    literals = pl.sql("SELECT 1 AS one, '2' AS two, 3.0 AS three", eager=True)
    expected = {
        "one": [1],
        "two": ["2"],
        "three": [3.0],
    }
    assert literals.to_dict(as_series=False) == expected
229
230
231
def test_select_from_table_with_reserved_names() -> None:
    """Double-quoted reserved words work as both table and column names."""
    # The local must be called `select`: the query's FROM clause refers to the
    # frame by this name, which is why the "unused variable" lint is silenced.
    select = pl.DataFrame({"select": [1, 2, 3], "from": [4, 5, 6]})  # noqa: F841
    matching = pl.sql(
        """
        SELECT "from", "select"
        FROM "select"
        WHERE "from" >= 5 AND "select" % 2 != 1
        """,
        eager=True,
    )
    assert matching.rows() == [(5, 2)]
242
243