Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_array.py
6939 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
9
from polars.testing import assert_frame_equal
10
11
12
@pytest.mark.parametrize(
    ("sort_order", "limit", "expected"),
    [
        (None, None, [("a", ["x", "y"]), ("b", ["z", "X", "Y"])]),
        ("ASC", None, [("a", ["x", "y"]), ("b", ["z", "Y", "X"])]),
        ("DESC", None, [("a", ["y", "x"]), ("b", ["X", "Y", "z"])]),
        ("ASC", 2, [("a", ["x", "y"]), ("b", ["z", "Y"])]),
        ("DESC", 2, [("a", ["y", "x"]), ("b", ["X", "Y"])]),
        ("ASC", 1, [("a", ["x"]), ("b", ["z"])]),
        ("DESC", 1, [("a", ["y"]), ("b", ["X"])]),
    ],
)
def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) -> None:
    """Check ARRAY_AGG per group, optionally with ORDER BY and/or LIMIT.

    The ``ORDER BY col0 <sort_order>`` and ``LIMIT <limit>`` modifiers are
    spliced into the aggregate call only when the corresponding parameter is
    truthy. Groups are returned ordered by ``col1`` and compared as rows.
    """
    # Assemble the in-aggregate modifiers; each is skipped when falsy.
    modifiers = ""
    if sort_order:
        modifiers += f" ORDER BY col0 {sort_order}"
    if limit:
        modifiers += f" LIMIT {limit}"

    query = f"""
        WITH data (col0, col1, col2) as (
          VALUES
            (1,'a','x'),
            (2,'a','y'),
            (4,'b','z'),
            (8,'b','X'),
            (7,'b','Y')
        )
        SELECT col1, ARRAY_AGG(col2{modifiers}) AS arrs
        FROM data
        GROUP BY col1
        ORDER BY col1
    """
    frame = pl.sql(query).collect()

    assert frame.rows() == expected
46
47
48
def test_array_literals() -> None:
    """Check array literals together with ARRAY_AGG/CONTAINS/REVERSE.

    The query runs in an eager ``SQLContext`` created with ``df=None``. The
    inner SELECT declares one integer and one string array literal (note the
    trailing comma before ``FROM``, which the query relies on being accepted).
    """
    query = """
        SELECT
          a1, a2,
          -- test some array ops
          ARRAY_AGG(a1) AS a3,
          ARRAY_AGG(a2) AS a4,
          ARRAY_CONTAINS(a1,20) AS i20,
          ARRAY_CONTAINS(a2,'zz') AS izz,
          ARRAY_REVERSE(a1) AS ar1,
          ARRAY_REVERSE(a2) AS ar2
        FROM (
          SELECT
            -- declare array literals
            [10,20,30] AS a1,
            ['a','b','c'] AS a2,
          FROM df
        ) tbl
    """
    with pl.SQLContext(df=None, eager=True) as ctx:
        result = ctx.execute(query)
        # ARRAY_AGG over the single row wraps each literal in one more list.
        want = pl.DataFrame(
            {
                "a1": [[10, 20, 30]],
                "a2": [["a", "b", "c"]],
                "a3": [[[10, 20, 30]]],
                "a4": [[["a", "b", "c"]]],
                "i20": [True],
                "izz": [False],
                "ar1": [[30, 20, 10]],
                "ar2": [["c", "b", "a"]],
            }
        )
        assert_frame_equal(result, want)
85
86
87
@pytest.mark.parametrize(
    ("array_index", "expected"),
    [
        (-4, None),
        (-3, 99),
        (-2, 66),
        (-1, 33),
        (0, None),
        (1, 99),
        (2, 66),
        (3, 33),
        (4, None),
    ],
)
def test_array_indexing(array_index: int, expected: int | None) -> None:
    """Check that ``arr[i]`` and ``ARRAY_GET(arr, i)`` agree on a literal array.

    Per the parametrization, indexes are 1-based and negative values count from
    the end. Index 0 and out-of-range indexes yield null.
    """
    sql = f"""
        SELECT
          arr[{array_index}] AS idx1,
          ARRAY_GET(arr,{array_index}) AS idx2,
        FROM (SELECT [99,66,33] AS arr) tbl
    """
    result = pl.sql(sql).collect()

    # dtypes are not compared: an all-null expected column has no integer dtype
    want = pl.DataFrame({name: [expected] for name in ("idx1", "idx2")})
    assert_frame_equal(result, want, check_dtypes=False)
118
119
120
def test_array_indexing_by_expr() -> None:
    """Check indexing arrays with a per-row index column rather than a literal.

    The index column covers negative, zero, null and positive values. Both the
    subscript form and ARRAY_GET must produce the same results.
    """
    source = pl.DataFrame(
        {
            "idx": [-2, -1, 0, None, 1, 2, 3],
            "arr": [[0, 1, 2, 3], [4, 5], [6], [7, 8, 9], [8, 7], [6, 5, 4], [3, 2, 1]],
        }
    )
    result = source.sql(
        """
        SELECT
          arr[idx] AS idx1,
          ARRAY_GET(arr, idx) AS idx2
        FROM self
        """
    )

    # a 0 or null index gives null; negatives count from the end
    picked = [2, 5, None, None, 8, 5, 1]
    assert_frame_equal(result, pl.DataFrame({"idx1": picked, "idx2": picked}))
137
138
139
def test_array_to_string() -> None:
    """Check ARRAY_TO_STRING on string and integer arrays, plus its arity check.

    Per the expected output, null elements are skipped when no null
    replacement is given. When the third argument is supplied, each null is
    rendered as that placeholder ('NA'). Calling it with a single argument
    raises ``SQLSyntaxError``.
    """
    source = pl.DataFrame(
        {
            "s_values": [["aa", "bb"], [None, "cc"], ["dd", None]],
            "n_values": [[999, 777], [None, 555], [333, None]],
        }
    )
    joined = source.sql(
        """
        SELECT
          ARRAY_TO_STRING(s_values, '') AS vs1,
          ARRAY_TO_STRING(s_values, ':') AS vs2,
          ARRAY_TO_STRING(s_values, ':', 'NA') AS vs3,
          ARRAY_TO_STRING(n_values, '') AS vn1,
          ARRAY_TO_STRING(n_values, ':') AS vn2,
          ARRAY_TO_STRING(n_values, ':', 'NA') AS vn3
        FROM self
        """
    )
    want = pl.DataFrame(
        {
            "vs1": ["aabb", "cc", "dd"],
            "vs2": ["aa:bb", "cc", "dd"],
            "vs3": ["aa:bb", "NA:cc", "dd:NA"],
            "vn1": ["999777", "555", "333"],
            "vn2": ["999:777", "555", "333"],
            "vn3": ["999:777", "NA:555", "333:NA"],
        }
    )
    assert_frame_equal(joined, want)

    # too few arguments is a syntax error
    with pytest.raises(
        SQLSyntaxError,
        match=r"ARRAY_TO_STRING expects 2-3 arguments \(found 1\)",
    ):
        pl.sql_expr("ARRAY_TO_STRING(arr)")
174
175
176
@pytest.mark.parametrize(
    "array_keyword",
    ["ARRAY", ""],
)
def test_unnest_table_function(array_keyword: str) -> None:
    """Check that UNNEST over several arrays yields one aliased column per array.

    The test is parametrized so the array literals are written both with and
    without a leading ``ARRAY`` keyword.
    """
    kw = array_keyword
    query = f"""
        SELECT * FROM
          UNNEST(
            {kw}[1, 2, 3, 4],
            {kw}['ww','xx','yy','zz'],
            {kw}[23.0, 24.5, 28.0, 27.5]
          ) AS tbl (x,y,z);
    """
    want = pl.DataFrame(
        {
            "x": [1, 2, 3, 4],
            "y": ["ww", "xx", "yy", "zz"],
            "z": [23.0, 24.5, 28.0, 27.5],
        }
    )
    with pl.SQLContext(df=None, eager=True) as ctx:
        assert_frame_equal(ctx.execute(query), want)
202
203
204
def test_unnest_table_function_errors() -> None:
    """Check that invalid UNNEST table-function usage raises descriptive errors.

    Covered cases: an alias without column names, a column-name count that
    doesn't match the number of arrays, a missing alias, and the unsupported
    ``WITH OFFSET`` clause. All queries run in the same eager ``SQLContext``.
    """
    with pl.SQLContext(df=None, eager=True) as ctx:
        # alias given, but without a column-name list
        with pytest.raises(
            SQLSyntaxError,
            match=r'UNNEST table alias must also declare column names, eg: "frame data" \(a,b,c\)',
        ):
            ctx.execute('SELECT * FROM UNNEST([1, 2, 3]) AS "frame data"')

        # more column names than arrays
        with pytest.raises(
            SQLSyntaxError,
            match="UNNEST table alias requires 1 column name, found 2",
        ):
            ctx.execute("SELECT * FROM UNNEST([1, 2, 3]) AS tbl (a, b)")

        # fewer column names than arrays
        with pytest.raises(
            SQLSyntaxError,
            match="UNNEST table alias requires 2 column names, found 1",
        ):
            ctx.execute("SELECT * FROM UNNEST([1,2,3], [3,4,5]) AS tbl (a)")

        # no alias at all
        with pytest.raises(
            SQLSyntaxError,
            match=r"UNNEST table must have an alias",
        ):
            ctx.execute("SELECT * FROM UNNEST([1, 2, 3])")

        # WITH OFFSET is not supported. The alternation is grouped: a bare
        # top-level "...WITH OFFSET|ORDINALITY" pattern would also accept
        # any message that merely contains the word "ORDINALITY".
        with pytest.raises(
            SQLInterfaceError,
            match=r"UNNEST tables do not \(yet\) support WITH (?:OFFSET|ORDINALITY)",
        ):
            ctx.execute("SELECT * FROM UNNEST([1, 2, 3]) tbl (colx) WITH OFFSET")
235
236