Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_subqueries.py
8407 views
1
import pytest
2
3
import polars as pl
4
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
5
from polars.testing import assert_frame_equal
6
7
8
@pytest.mark.parametrize(
    ("cols", "join_type", "constraint"),
    [
        ("x", "INNER", ""),
        ("y", "INNER", ""),
        ("x", "LEFT", "WHERE y IN (0,1,2,3,4,5)"),
        ("y", "LEFT", "WHERE y >= 0"),
        ("df1.*", "FULL", "WHERE y >= 0"),
        ("df2.*", "FULL", "WHERE x >= 0"),
        ("* EXCLUDE y", "LEFT", "WHERE y >= 0"),
        ("* EXCLUDE x", "LEFT", "WHERE x >= 0"),
    ],
)
def test_from_subquery(cols: str, join_type: str, constraint: str) -> None:
    """Join two aliased derived tables (FROM-clause subqueries) and filter.

    Whatever the projection and join type, the result must hold exactly the
    matched keys 0..3. For the outer joins, the WHERE constraint removes the
    rows whose key found no partner (the ``-1`` rows of ``df1``).
    """
    left = pl.DataFrame({"x": [-1, 0, 3, 1, 2, -1]})
    right = pl.DataFrame({"y": [0, 1, 2, 3]})

    # the context registers the frames under the table names used in the SQL
    ctx = pl.SQLContext(df1=left, df2=right)
    query = f"""
        SELECT {cols} FROM (SELECT * FROM df1) AS df1
        {join_type} JOIN (SELECT * FROM df2) AS df2
        ON df1.x = df2.y {constraint}
    """
    result = ctx.execute(query=query, eager=True)

    # only the first (single) result column is inspected
    keys = sorted(result.to_series())
    assert keys == [0, 1, 2, 3]
35
36
37
@pytest.mark.may_fail_cloud  # reason: with_context
def test_in_subquery() -> None:
    """Check ``IN`` / ``NOT IN`` filters whose right-hand side is a subquery.

    Covers:
    - a subquery over the same table;
    - two subquery filters combined with ``AND``;
    - arithmetic expressions on both sides of ``IN``;
    - ``NOT IN``;
    - string columns;
    - the error raised when the subquery returns more than one column.

    Every frame comparison passes the expected frame as ``left`` and the query
    result as ``right``, so failure diffs are labelled consistently.
    """
    df = pl.DataFrame(
        {
            "x": [1, 2, 3, 4, 5, 6],
            "y": [2, 3, 4, 5, 6, 7],
        }
    )
    df_other = pl.DataFrame(
        {
            "w": [1, 2, 3, 4, 5, 6],
            "z": [2, 3, 4, 5, 6, 7],
        }
    )
    df_chars = pl.DataFrame(
        {
            "one": ["a", "b", "c", "d", "e", "f"],
            "two": ["b", "c", "d", "e", "f", "g"],
        }
    )

    ctx = pl.SQLContext(df=df, df_other=df_other, df_chars=df_chars)

    # x must appear among the y values (2..7) of the same table
    res_same = ctx.execute(
        query="""
            SELECT df.x as x
            FROM df
            WHERE x IN (SELECT y FROM df)
        """,
        eager=True,
    )
    df_expected_same = pl.DataFrame({"x": [2, 3, 4, 5, 6]})
    assert_frame_equal(
        left=df_expected_same,
        right=res_same,
    )

    # two independent subquery filters combined with AND
    res_double = ctx.execute(
        query="""
            SELECT df.x as x
            FROM df
            WHERE x IN (SELECT y FROM df)
            AND y IN (SELECT w FROM df_other)
        """,
        eager=True,
    )
    df_expected_double = pl.DataFrame({"x": [2, 3, 4, 5]})
    assert_frame_equal(
        left=df_expected_double,
        right=res_double,
    )

    # expressions on the probe side (x+1) and inside the subquery (w-1)
    res_expressions = ctx.execute(
        query="""
            SELECT
              df.x as x
            FROM df
            WHERE x+1 IN (SELECT y FROM df)
            AND y IN (SELECT w-1 FROM df_other)
        """,
        eager=True,
    )
    df_expected_expressions = pl.DataFrame({"x": [1, 2, 3, 4]})
    assert_frame_equal(
        left=df_expected_expressions,
        right=res_expressions,
    )

    # NOT IN: x outside {-3..2} (y-5) and y outside {6..11} (w+5)
    res_not_in = ctx.execute(
        query="""
            SELECT
              df.x as x
            FROM df
            WHERE x NOT IN (SELECT y-5 FROM df)
            AND y NOT IN (SELECT w+5 FROM df_other)
        """,
        eager=True,
    )
    df_expected_not_in = pl.DataFrame({"x": [3, 4]})
    assert_frame_equal(
        left=df_expected_not_in,
        right=res_not_in,
    )

    # string columns
    res_chars = ctx.execute(
        query="""
            SELECT
              df_chars.one
            FROM df_chars
            WHERE one IN (SELECT two FROM df_chars)
        """,
        eager=True,
    )
    df_expected_chars = pl.DataFrame({"one": ["b", "c", "d", "e", "f"]})
    assert_frame_equal(
        left=df_expected_chars,
        right=res_chars,
    )

    # an IN subquery must project exactly one column
    with pytest.raises(
        expected_exception=SQLSyntaxError,
        match="SQL subquery returns more than one column",
    ):
        ctx.execute(
            query="""
                SELECT
                  df_chars.one
                FROM df_chars
                WHERE one IN (SELECT one, two FROM df_chars)
            """
        ).collect()
147
148
149
def test_subquery_20732() -> None:
    """``IN`` against an aggregate (``MAX``) subquery over a concatenated frame.

    Presumably a regression test for polars issue #20732, going by the test
    name (TODO confirm). The frame is built by concatenating two single-row
    LazyFrames, and only the row holding the largest ``id`` should survive.
    """
    pieces = [
        pl.LazyFrame([{"id": 1, "s": "a"}]),
        pl.LazyFrame([{"id": 2, "s": "b"}]),
    ]
    # the query refers to the table as ``lf``, and ``pl.sql`` finds frames by
    # the caller's variable names, so this binding must keep that name
    lf = pl.concat(pieces)

    query = "SELECT * FROM lf WHERE id IN (SELECT MAX(id) FROM lf)"
    out = pl.sql(query, eager=True)
    assert out.to_dict(as_series=False) == {"id": [2], "s": ["b"]}
158
159
160
def test_unsupported_subquery_comparisons() -> None:
    """Comparing a value against a subquery raises a helpful SQLSyntaxError.

    ``=`` and ``!=`` are rejected with a suggestion to use ``IN`` / ``NOT IN``.
    The ordering operators (``<``, ``<=``, ``>``, ``>=``) are rejected with the
    subquery on either side of the comparison.
    """
    df = pl.DataFrame({"value": [2000, 2000]})

    # The subqueries aggregate the real column ``value``, so the comparison is
    # the only thing wrong with each query. ``.collect()`` keeps the check valid
    # even if validation were deferred from translation to execution.
    for op, suggestion in (("=", "IN"), ("!=", "NOT IN")):
        with pytest.raises(
            expected_exception=SQLSyntaxError,
            match=rf"subquery comparisons with '{op}' are not supported; use '{suggestion}' instead",
        ):
            pl.sql(
                f"SELECT * FROM df WHERE value {op} (SELECT MAX(value) FROM df)"
            ).collect()

    for op in ("<", "<=", ">", ">="):
        # subquery on the left-hand side, then on the right-hand side
        for query in (
            f"SELECT * FROM df WHERE (SELECT MAX(value) FROM df) {op} value",
            f"SELECT * FROM df WHERE value {op} (SELECT MAX(value) FROM df)",
        ):
            with pytest.raises(
                expected_exception=SQLSyntaxError,
                match=rf"subquery comparisons with '{op}' are not supported",
            ):
                pl.sql(query).collect()
183
184
185
def test_derived_table_without_alias() -> None:
    """Derived tables (FROM-clause subqueries) without an alias are queryable.

    Covers:
    - ``SELECT *`` over a bare subquery;
    - a ``UNION ALL`` wrapped in an unaliased subquery;
    - unqualified column references into an unaliased subquery.
    """
    df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    with pl.SQLContext(df=df) as ctx:
        # basic unaliased subquery
        everything = ctx.execute(
            "SELECT * FROM (SELECT a, b FROM df) ORDER BY a", eager=True
        )
        assert_frame_equal(everything, df)

        # set operation without subquery aliases (lazy result, collected here)
        unioned = ctx.execute(
            """
            SELECT * FROM (
              SELECT a, b FROM df WHERE a <= 2
              UNION ALL
              SELECT a, b FROM df WHERE a > 2
            )
            ORDER BY a
            """
        ).collect()
        assert_frame_equal(unioned, df)

        # unqualified (but unambiguous) column refs from unaliased derived table
        just_a = ctx.execute(
            "SELECT a FROM (SELECT a, b FROM df) ORDER BY a", eager=True
        )
        assert_frame_equal(just_a, df.select("a"))
209
210
211
def test_derived_table_alias_errors() -> None:
    """Unaliased derived tables cannot be joined or referenced by qualifier.

    - Joining onto an unnamed subquery raises "cannot JOIN on unnamed relation".
    - Qualifying a column or wildcard with the never-declared alias ``sq``
      raises "no table or struct column named 'sq' found". This is checked in
      the projection, WHERE, GROUP BY and ORDER BY clauses.
    """
    # ``pl.sql`` finds this frame by its variable name, so it must stay ``df``
    df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    unnamed_relation = "cannot JOIN on unnamed relation"

    # joining on an unaliased derived table should raise (CROSS takes no ON)
    for join_type in ("INNER", "LEFT", "CROSS"):
        on_clause = "" if join_type == "CROSS" else "ON df.a = a2"
        with pytest.raises(
            expected_exception=SQLInterfaceError,
            match=unnamed_relation,
        ):
            pl.sql(
                query=f"""
                    SELECT * FROM df
                    {join_type} JOIN (SELECT a AS a2 FROM df) {on_clause}
                """
            ).collect()

    # both sides of the join are unaliased derived tables
    with pytest.raises(
        expected_exception=SQLInterfaceError,
        match=unnamed_relation,
    ):
        pl.sql(
            query="""
                SELECT *
                FROM (SELECT a FROM df)
                INNER JOIN (SELECT b FROM df) ON a = b
            """,
        ).collect()

    # qualified references to the nonexistent alias "sq", in turn:
    # wildcard, projected column, WHERE, GROUP BY, ORDER BY
    for bad_query in (
        "SELECT sq.* FROM (SELECT a, b FROM df)",
        "SELECT sq.a FROM (SELECT a, b FROM df)",
        "SELECT a FROM (SELECT a, b FROM df) WHERE sq.a > 1",
        "SELECT a, COUNT(*) FROM (SELECT a, b FROM df) GROUP BY sq.a",
        "SELECT a FROM (SELECT a, b FROM df) ORDER BY sq.a",
    ):
        with pytest.raises(
            expected_exception=SQLInterfaceError,
            match="no table or struct column named 'sq' found",
        ):
            pl.sql(query=bad_query, eager=True)
)
288
289