Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_set_ops.py
6939 views
1
from __future__ import annotations

import re

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError
from polars.testing import assert_frame_equal
8
9
10
def test_except_intersect() -> None:
    """EXCEPT / INTERSECT return distinct rows and treat NULLs as equal."""
    # Neither frame is used directly in Python (hence F841); the SQL queries
    # refer to them by their variable names.
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})  # noqa: F841

    # explicit column list on one side, `*` on the other
    except_rows = pl.sql(
        "SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2", eager=True
    ).rows()
    intersect_rows = pl.sql(
        "SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2", eager=True
    ).rows()
    assert sorted(except_rows) == [(1, 2, 5), (9, 3, 5)]
    # (1, 4, 5) occurs twice in df1 but only once in the result
    assert sorted(intersect_rows) == [(1, 4, 5)]

    # `TABLE df1` shorthand, plus an INTERSECT against an inline VALUES
    # relation (cast to int8) holding the same rows as df1
    except_rows = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True).rows()
    intersect_rows = pl.sql(
        """
        SELECT * FROM df2
        INTERSECT
        SELECT x::int8, y::int8, z::int8
        FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z)
        """,
        eager=True,
    ).rows()
    assert sorted(except_rows) == [(1, 2, 7), (9, None, 6)]
    assert sorted(intersect_rows) == [(1, 4, 5)]

    # NULL handling: the (9, NULL) row of tbl1 is removed by the identical
    # (9, NULL) row of tbl2, so only (2, 2) survives the EXCEPT
    with pl.SQLContext(
        tbl1=pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]}),
        tbl2=pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]}),
    ) as ctx:
        res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True)
        assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)
40
41
42
def test_except_intersect_by_name() -> None:
    """EXCEPT / INTERSECT BY NAME align operand columns by name, not position."""
    # df2 stores its columns in a different order and has an extra "w" column;
    # both frames are referenced only by name from the SQL (hence F841).
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
    df2 = pl.DataFrame(  # noqa: F841
        {"y": [2, None, 4], "w": ["?", "!", "%"], "z": [7, 6, 5], "x": [1, 9, 1]}
    )

    except_df = pl.sql(
        "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2",
        eager=True,
    )
    intersect_df = pl.sql(
        "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2",
        eager=True,
    )
    assert sorted(except_df.rows()) == [(1, 2, 5), (9, 3, 5)]
    assert sorted(intersect_df.rows()) == [(1, 4, 5)]

    # both results keep df1's column order and omit the df2-only "w" column
    for result in (except_df, intersect_df):
        assert result.columns == ["x", "y", "z"]
70
71
72
@pytest.mark.parametrize(
    ("op", "op_subtype"),
    [
        (set_op, subtype)
        for set_op in ("EXCEPT", "INTERSECT")
        for subtype in ("ALL", "ALL BY NAME")
    ],
)
def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None:
    """The ALL / ALL BY NAME forms of EXCEPT and INTERSECT raise SQLInterfaceError."""
    # frames with duplicate values, referenced only from the SQL text (F841)
    df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]})  # noqa: F841
    df2 = pl.DataFrame({"n": [1, 1, 2, 2]})  # noqa: F841

    clause = f"{op} {op_subtype}"
    with pytest.raises(SQLInterfaceError, match=f"'{clause}' is not supported"):
        pl.sql(f"SELECT * FROM df1 {clause} SELECT * FROM df2")
90
91
92
def test_update_statement_error() -> None:
    """A CTE-driven ``UPDATE ... FROM`` statement raises SQLInterfaceError."""
    df_large = pl.DataFrame(
        {
            "FQDN": ["c.ORG.na", "a.COM.na"],
            "NS1": ["ns1.c.org.na", "ns1.d.net.na"],
            "NS2": ["ns2.c.org.na", "ns2.d.net.na"],
            "NS3": ["ns3.c.org.na", "ns3.d.net.na"],
        }
    )
    df_small = pl.DataFrame(
        {
            "FQDN": ["c.org.na"],
            "NS1": ["ns1.c.org.na|127.0.0.1"],
            "NS2": ["ns2.c.org.na|127.0.0.1"],
            "NS3": ["ns3.c.org.na|127.0.0.1"],
        }
    )

    # create a context with both tables registered
    ctx = pl.SQLContext(large=df_large, small=df_small)

    # The expected message echoes the statement (on a single line) and so
    # contains regex metacharacters such as "."; `match` performs a regex
    # search, so escape the text to make the check literal.
    expected_error = re.escape(
        "'UPDATE large SET FQDN = u.FQDN, NS1 = u.NS1, NS2 = u.NS2, NS3 = u.NS3 "
        "FROM u WHERE large.FQDN = u.FQDN' operation is currently unsupported"
    )
    with pytest.raises(SQLInterfaceError, match=expected_error):
        ctx.execute(
            """
            WITH u AS (
                SELECT
                    small.FQDN,
                    small.NS1,
                    small.NS2,
                    small.NS3
                FROM small
                INNER JOIN large ON small.FQDN = large.FQDN
            )
            UPDATE large
            SET
                FQDN = u.FQDN,
                NS1 = u.NS1,
                NS2 = u.NS2,
                NS3 = u.NS3
            FROM u
            WHERE large.FQDN = u.FQDN
            """
        )
138
139
140
@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
def test_except_intersect_errors(op: str) -> None:
    """Invalid set-operation queries raise SQLInterfaceError for each operator."""
    # frames are referenced only by name from the SQL text (hence F841)
    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})  # noqa: F841

    # 'UNION ALL' is supported, so the ALL rejection only applies to
    # EXCEPT and INTERSECT
    is_union = op == "UNION"
    if not is_union:
        with pytest.raises(SQLInterfaceError, match=f"'{op} ALL' is not supported"):
            pl.sql(f"SELECT * FROM df1 {op} ALL SELECT * FROM df2", eager=False)

    # operands with different column counts (1 vs 2) are always rejected
    column_count_error = f"{op} requires equal number of columns in each table"
    with pytest.raises(SQLInterfaceError, match=column_count_error):
        pl.sql(f"SELECT x FROM df1 {op} SELECT x, y FROM df2", eager=False)
157
158
159
@pytest.mark.parametrize(
    ("cols1", "cols2", "union_subtype", "expected"),
    [
        (
            ["*"],
            ["*"],
            "",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["frame2.*"],
            "ALL",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["frame1.*"],
            ["c1", "c2"],
            "DISTINCT",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["*"],
            ["c2", "c1"],
            "ALL BY NAME",
            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
        (
            ["c1", "c2"],
            ["c2", "c1"],
            "DISTINCT BY NAME",
            [(1, "zz"), (2, "yy"), (3, "xx")],
        ),
    ],
)
def test_union(
    cols1: list[str],
    cols2: list[str],
    union_subtype: str,
    expected: list[tuple[int, str]],
) -> None:
    """UNION variants: duplicates kept only with ALL; BY NAME realigns columns.

    Parameters
    ----------
    cols1
        Projection list for the left-hand SELECT over ``frame1``.
    cols2
        Projection list for the right-hand SELECT over ``frame2``.
    union_subtype
        Text placed after ``UNION`` (empty string for a bare UNION).
    expected
        The result rows, in sorted order.
    """
    # the two frames share the row (2, "yy"); only the ALL variants keep both
    # copies. In the BY NAME cases the right side selects (c2, c1), yet the
    # expected rows are still (c1, c2).
    with pl.SQLContext(
        frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
        frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}),
        eager=True,
    ) as ctx:
        query = f"""
            SELECT {", ".join(cols1)} FROM frame1
            UNION {union_subtype}
            SELECT {", ".join(cols2)} FROM frame2
        """
        assert sorted(ctx.execute(query).rows()) == expected
217
218