Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_strings.py
6939 views
1
from __future__ import annotations
2
3
from pathlib import Path
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import SQLSyntaxError
9
from polars.testing import assert_frame_equal
10
11
12
# TODO: Do not rely on I/O for these tests
@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the path to the shared 'foods1' IPC fixture file."""
    return Path(__file__).parents[1] / "io" / "files" / "foods1.ipc"
16
17
18
def test_string_case() -> None:
    """Verify the SQL INITCAP/UPPER/LOWER string-casing functions."""
    frame = pl.DataFrame({"words": ["Test SOME words"]})

    with pl.SQLContext(frame=frame) as ctx:
        out = ctx.execute(
            """
            SELECT
              words,
              INITCAP(words) as cap,
              UPPER(words) as upper,
              LOWER(words) as lower,
            FROM frame
            """
        ).collect()

    expected = {
        "words": ["Test SOME words"],
        "cap": ["Test Some Words"],
        "upper": ["TEST SOME WORDS"],
        "lower": ["test some words"],
    }
    assert out.to_dict(as_series=False) == expected
39
40
41
def test_string_concat() -> None:
    """Check '||', CONCAT and CONCAT_WS null-propagation and separator rules."""
    lf = pl.LazyFrame(
        {"x": ["a", None, "c"], "y": ["d", "e", "f"], "z": [1, 2, 3]}
    )
    out = lf.sql(
        """
        SELECT
          ("x" || "x" || "y") AS c0,
          ("x" || "y" || "z") AS c1,
          CONCAT(("x" || '-'), "y") AS c2,
          CONCAT("x", "x", "y") AS c3,
          CONCAT("x", "y", ("z" * 2)) AS c4,
          CONCAT_WS(':', "x", "y", "z") AS c5,
          CONCAT_WS('', "y", "z", '!') AS c6
        FROM self
        """,
    ).collect()

    # note: '||' propagates nulls; CONCAT/CONCAT_WS skip them
    expected = {
        "c0": ["aad", None, "ccf"],
        "c1": ["ad1", None, "cf3"],
        "c2": ["a-d", "e", "c-f"],
        "c3": ["aad", "e", "ccf"],
        "c4": ["ad2", "e4", "cf6"],
        "c5": ["a:d:1", "e:2", "c:f:3"],
        "c6": ["d1!", "e2!", "f3!"],
    }
    assert out.to_dict(as_series=False) == expected
72
73
74
@pytest.mark.parametrize(
    "invalid_concat", ["CONCAT()", "CONCAT_WS()", "CONCAT_WS(':')"]
)
def test_string_concat_errors(invalid_concat: str) -> None:
    """Too few arguments to CONCAT/CONCAT_WS must raise SQLSyntaxError."""
    lf = pl.LazyFrame({"x": ["a", "b", "c"]})
    expected_err = r"CONCAT.*expects at least \d argument[s]? \(found \d\)"
    with pytest.raises(SQLSyntaxError, match=expected_err):
        pl.SQLContext(data=lf).execute(f"SELECT {invalid_concat} FROM data")
84
85
86
def test_string_left_right_reverse() -> None:
    """LEFT/RIGHT/REVERSE results, plus rejection of invalid 'n_chars' args."""
    ctx = pl.SQLContext(df=pl.DataFrame({"txt": ["abcde", "abc", "a", None]}))
    out = ctx.execute(
        """
        SELECT
          LEFT(txt,2) AS "l",
          RIGHT(txt,2) AS "r",
          REVERSE(txt) AS "rev"
        FROM df
        """,
    ).collect()

    expected = {
        "l": ["ab", "ab", "a", None],
        "r": ["de", "bc", "a", None],
        "rev": ["edcba", "cba", "a", None],
    }
    assert out.to_dict(as_series=False) == expected

    # non-integer 'n_chars' values must be rejected with a syntax error
    error_cases = [
        ("LEFT", "'xyz'", '"xyz"'),
        ("RIGHT", "6.66", "(dyn float: 6.66)"),
    ]
    for func, invalid_arg, invalid_err in error_cases:
        with pytest.raises(
            SQLSyntaxError,
            match=rf"""invalid 'n_chars' for {func} \({invalid_err}\)""",
        ):
            ctx.execute(f"""SELECT {func}(txt,{invalid_arg}) FROM df""").collect()
113
114
115
def test_string_left_negative_expr() -> None:
    """LEFT with negative, zero, NULL, expression, and column-valued 'n_chars'."""
    # negative values and expressions
    df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
    with pl.SQLContext(df=df, eager=True) as sql:
        res = sql.execute(
            """
            SELECT
              LEFT("s",-50) AS l0,       -- empty string
              LEFT("s",-3) AS l1,        -- all but last three chars
              LEFT("s",SIGN(-1)) AS l2,  -- all but last char (expr => -1)
              LEFT("s",0) AS l3,         -- empty string
              LEFT("s",NULL) AS l4,      -- null
              LEFT("s",1) AS l5,         -- first char
              LEFT("s",SIGN(1)) AS l6,   -- first char (expr => 1)
              LEFT("s",3) AS l7,         -- first three chars
              LEFT("s",50) AS l8,        -- entire string
              LEFT("s","n") AS l9,       -- from other col
            FROM df
            """
        )
        # "l9" differs per-row because the count comes from column "n" (-6 vs 6)
        assert res.to_dict(as_series=False) == {
            "l0": ["", ""],
            "l1": ["alpha", "alpha"],
            "l2": ["alphabe", "alphabe"],
            "l3": ["", ""],
            "l4": [None, None],
            "l5": ["a", "a"],
            "l6": ["a", "a"],
            "l7": ["alp", "alp"],
            "l8": ["alphabet", "alphabet"],
            "l9": ["al", "alphab"],
        }
147
148
149
def test_string_right_negative_expr() -> None:
    """RIGHT with negative, zero, NULL, expression, and column-valued 'n_chars'."""
    # negative values and expressions
    df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
    with pl.SQLContext(df=df, eager=True) as sql:
        res = sql.execute(
            """
            SELECT
              RIGHT("s",-50) AS l0,       -- empty string
              RIGHT("s",-3) AS l1,        -- all but first three chars
              RIGHT("s",SIGN(-1)) AS l2,  -- all but first char (expr => -1)
              RIGHT("s",0) AS l3,         -- empty string
              RIGHT("s",NULL) AS l4,      -- null
              RIGHT("s",1) AS l5,         -- last char
              RIGHT("s",SIGN(1)) AS l6,   -- last char (expr => 1)
              RIGHT("s",3) AS l7,         -- last three chars
              RIGHT("s",50) AS l8,        -- entire string
              RIGHT("s","n") AS l9,       -- from other col
            FROM df
            """
        )
        # "l9" differs per-row because the count comes from column "n" (-6 vs 6)
        assert res.to_dict(as_series=False) == {
            "l0": ["", ""],
            "l1": ["habet", "habet"],
            "l2": ["lphabet", "lphabet"],
            "l3": ["", ""],
            "l4": [None, None],
            "l5": ["t", "t"],
            "l6": ["t", "t"],
            "l7": ["bet", "bet"],
            "l8": ["alphabet", "alphabet"],
            "l9": ["et", "phabet"],
        }
181
182
183
def test_string_lengths() -> None:
    """Char/byte/bit length functions on unicode, null, and empty strings."""
    words = ["Café", None, "東京", ""]
    df = pl.DataFrame({"words": words})

    with pl.SQLContext(frame=df) as ctx:
        out = ctx.execute(
            """
            SELECT
              words,
              LENGTH(words) AS n_chrs1,
              CHAR_LENGTH(words) AS n_chrs2,
              CHARACTER_LENGTH(words) AS n_chrs3,
              OCTET_LENGTH(words) AS n_bytes,
              BIT_LENGTH(words) AS n_bits
            FROM frame
            """
        ).collect()

    # LENGTH/CHAR_LENGTH/CHARACTER_LENGTH all count characters, not bytes
    n_chars = [4, None, 2, 0]
    assert out.to_dict(as_series=False) == {
        "words": words,
        "n_chrs1": n_chars,
        "n_chrs2": n_chars,
        "n_chrs3": n_chars,
        "n_bytes": [5, None, 6, 0],
        "n_bits": [40, None, 48, 0],
    }
208
209
210
@pytest.mark.parametrize(
    ("pattern", "like", "expected"),
    [
        ("a%", "LIKE", [1, 4]),
        ("a%", "ILIKE", [0, 1, 3, 4]),
        ("ab%", "LIKE", [1]),
        ("AB%", "ILIKE", [0, 1]),
        ("ab_", "LIKE", [1]),
        ("A__", "ILIKE", [0, 1]),
        ("_0%_", "LIKE", [2, 4]),
        ("%0", "LIKE", [2]),
        ("0%", "LIKE", [2]),
        ("__0%", "~~", [2, 3]),
        ("%*%", "~~*", [3]),
        ("____", "~~", [4]),
        ("a%C", "~~", []),
        ("a%C", "~~*", [0, 1, 3]),
        ("%C?", "~~*", [4]),
        ("a0c?", "~~", [4]),
        ("000", "~~", [2]),
        ("00", "~~", []),
    ],
)
def test_string_like(pattern: str, like: str, expected: list[int]) -> None:
    """LIKE/ILIKE (and their ~~/~~* operator forms), with and without negation."""
    df = pl.DataFrame(
        {
            "idx": [0, 1, 2, 3, 4],
            "txt": ["ABC", "abc", "000", "A[0]*C", "a0c?"],
        }
    )
    with pl.SQLContext(df=df) as ctx:
        # keyword forms negate with a NOT prefix; operator forms with '!'
        negation = "NOT " if like.endswith("LIKE") else "!"
        for not_ in ("", negation):
            query = f"SELECT idx FROM df WHERE txt {not_}{like} '{pattern}'"
            out = ctx.execute(query).collect()

            res = out["idx"].to_list()
            if not_:
                # negated matches are the complement of the positive matches
                expected = [i for i in df["idx"] if i not in expected]
            assert res == expected
250
251
252
def test_string_like_multiline() -> None:
    """LIKE/ILIKE patterns must match across embedded newlines."""
    s1 = "Hello World"
    s2 = "Hello\nWorld"
    s3 = "hello\nWORLD"

    df = pl.DataFrame({"idx": [0, 1, 2], "txt": [s1, s2, s3]})

    # starts with...
    starts_cs = df.sql("SELECT * FROM self WHERE txt LIKE 'Hello%' ORDER BY idx")
    starts_ci = df.sql("SELECT * FROM self WHERE txt ILIKE 'HELLO%' ORDER BY idx")

    assert starts_cs["txt"].to_list() == [s1, s2]
    assert starts_ci["txt"].to_list() == [s1, s2, s3]

    # ends with...
    ends_cs = df.sql("SELECT * FROM self WHERE txt LIKE '%WORLD' ORDER BY idx")
    ends_ci = df.sql("SELECT * FROM self WHERE txt ILIKE '%\nWORLD' ORDER BY idx")

    assert ends_cs["txt"].to_list() == [s3]
    assert ends_ci["txt"].to_list() == [s2, s3]

    # exact match
    for s in (s1, s2, s3):
        assert df.sql(f"SELECT txt FROM self WHERE txt LIKE '{s}'").item() == s
278
@pytest.mark.parametrize("form", ["NFKC", "NFKD"])
def test_string_normalize(form: str) -> None:
    """NORMALIZE folds styled/circled unicode variants back to plain ASCII."""
    styled = ["Test", "𝕋𝕖𝕤𝕥", "𝕿𝖊𝖘𝖙", "𝗧𝗲𝘀𝘁", "Ⓣⓔⓢⓣ"]  # noqa: RUF001
    df = pl.DataFrame({"txt": styled})
    res = df.sql(
        f"""
        SELECT txt, NORMALIZE(txt,{form}) AS norm_txt
        FROM self
        """
    )
    assert res.to_dict(as_series=False) == {
        "txt": styled,
        "norm_txt": ["Test", "Test", "Test", "Test", "Test"],
    }
291
292
293
def test_string_position() -> None:
    """POSITION('x' IN col) and STRPOS(col,'x') locate substrings.

    Both are aliases, 1-indexed, case-sensitive, and return 0 when absent;
    the second half additionally pins the UInt32 result dtype.
    """
    df = pl.Series(
        name="city",
        values=["Dubai", "Abu Dhabi", "Sharjah", "Al Ain", "Ajman", "Ras Al Khaimah"],
    ).to_frame()

    with pl.SQLContext(cities=df, eager=True) as ctx:
        res = ctx.execute(
            """
            SELECT
              POSITION('a' IN city) AS a_lc1,
              POSITION('A' IN city) AS a_uc1,
              STRPOS(city,'a') AS a_lc2,
              STRPOS(city,'A') AS a_uc2,
            FROM cities
            """
        )
        # 0 => no match (e.g. no lowercase 'a' in "Al Ain")
        expected_lc = [4, 7, 3, 0, 4, 2]
        expected_uc = [0, 1, 0, 1, 1, 5]

        assert res.to_dict(as_series=False) == {
            "a_lc1": expected_lc,
            "a_uc1": expected_uc,
            "a_lc2": expected_lc,
            "a_uc2": expected_uc,
        }

    df = pl.DataFrame({"txt": ["AbCdEXz", "XyzFDkE"]})
    with pl.SQLContext(txt=df) as ctx:
        res = ctx.execute(
            """
            SELECT
              txt,
              POSITION('E' IN txt) AS match_E,
              STRPOS(txt,'X') AS match_X
            FROM txt
            """,
            eager=True,
        )
        # compare with an explicit schema: match positions come back as UInt32
        assert_frame_equal(
            res,
            pl.DataFrame(
                data={
                    "txt": ["AbCdEXz", "XyzFDkE"],
                    "match_E": [5, 7],
                    "match_X": [6, 1],
                },
                schema={
                    "txt": pl.String,
                    "match_E": pl.UInt32,
                    "match_X": pl.UInt32,
                },
            ),
        )
347
348
349
def test_string_replace() -> None:
    """Nested REPLACE calls compose; a 2-argument call raises SQLSyntaxError."""
    df = pl.DataFrame({"words": ["Yemeni coffee is the best coffee", "", None]})
    with pl.SQLContext(df=df) as ctx:
        query = """
        SELECT
          REPLACE(
            REPLACE(words, 'coffee', 'tea'),
            'Yemeni',
            'English breakfast'
          )
        FROM df
        """
        replaced = ctx.execute(query).collect()["words"].to_list()
        assert replaced == ["English breakfast tea is the best tea", "", None]

        # REPLACE requires exactly three arguments
        with pytest.raises(
            SQLSyntaxError, match=r"REPLACE expects 3 arguments \(found 2\)"
        ):
            ctx.execute("SELECT REPLACE(words,'coffee') FROM df")
373
def test_string_split() -> None:
    """STRING_TO_ARRAY splits on a delimiter, preserving empties and nulls."""
    values = ["xx,yy,zz", "abc,,xyz", "", None]
    res = pl.DataFrame({"s": values}).sql(
        "SELECT *, STRING_TO_ARRAY(s,',') AS s_array FROM self"
    )

    assert res.schema == {"s": pl.String, "s_array": pl.List(pl.String)}
    assert res.to_dict(as_series=False) == {
        "s": values,
        "s_array": [["xx", "yy", "zz"], ["abc", "", "xyz"], [""], None],
    }
382
383
384
def test_string_split_part() -> None:
    """SPLIT_PART with positive (1-based) and negative (from-the-end) indices."""
    df = pl.DataFrame({"s": ["xx,yy,zz", "abc,,xyz,???,hmm", "", None]})
    res = df.sql(
        """
        SELECT
          SPLIT_PART(s,',',1) AS "s+1",
          SPLIT_PART(s,',',3) AS "s+3",
          SPLIT_PART(s,',',-2) AS "s-2",
        FROM self
        """
    )
    expected = {
        "s+1": ["xx", "abc", "", None],
        "s+3": ["zz", "xyz", "", None],
        "s-2": ["yy", "???", "", None],
    }
    assert res.to_dict(as_series=False) == expected
400
401
402
def test_string_substr() -> None:
    """SUBSTR with 1-based, negative, and column-driven start/length arguments.

    Also checks the error paths: negative length and too few arguments.
    """
    df = pl.DataFrame(
        {"scol": ["abcdefg", "abcde", "abc", None], "n": [-2, 3, 2, None]}
    )
    with pl.SQLContext(df=df) as ctx:
        res = ctx.execute(
            """
            SELECT
              -- note: sql is 1-indexed
              SUBSTR(scol,1) AS s1,
              SUBSTR(scol,2) AS s2,
              SUBSTR(scol,3) AS s3,
              SUBSTR(scol,1,5) AS s1_5,
              SUBSTR(scol,2,2) AS s2_2,
              SUBSTR(scol,3,1) AS s3_1,
              SUBSTR(scol,-3) AS "s-3",
              SUBSTR(scol,-3,3) AS "s-3_3",
              SUBSTR(scol,-3,4) AS "s-3_4",
              SUBSTR(scol,-3,5) AS "s-3_5",
              SUBSTR(scol,-10,13) AS "s-10_13",
              SUBSTR(scol,"n",2) AS "s-n2",
              SUBSTR(scol,2,"n"+3) AS "s-2n3"
            FROM df
            """
        ).collect()

        # negative length is always invalid
        with pytest.raises(
            SQLSyntaxError,
            match=r"SUBSTR does not support negative length \(-99\)",
        ):
            ctx.execute("SELECT SUBSTR(scol,2,-99) FROM df")

        # SUBSTR needs at least a start position
        with pytest.raises(
            SQLSyntaxError,
            match=r"SUBSTR expects 2-3 arguments \(found 1\)",
        ):
            pl.sql_expr("SUBSTR(s)")

        # negative starts count back from position 1 (not from the string end),
        # so e.g. SUBSTR(scol,-3,5) only reaches the first character
        assert res.to_dict(as_series=False) == {
            "s1": ["abcdefg", "abcde", "abc", None],
            "s2": ["bcdefg", "bcde", "bc", None],
            "s3": ["cdefg", "cde", "c", None],
            "s1_5": ["abcde", "abcde", "abc", None],
            "s2_2": ["bc", "bc", "bc", None],
            "s3_1": ["c", "c", "c", None],
            "s-3": ["abcdefg", "abcde", "abc", None],
            "s-3_3": ["", "", "", None],
            "s-3_4": ["", "", "", None],
            "s-3_5": ["a", "a", "a", None],
            "s-10_13": ["ab", "ab", "ab", None],
            "s-n2": ["", "cd", "bc", None],
            "s-2n3": ["b", "bcde", "bc", None],
        }
455
456
457
def test_string_trim(foods_ipc_path: Path) -> None:
    """TRIM(LEADING ... FROM col) strips a leading charset; snowflake-style TRIM errors."""
    lf = pl.scan_ipc(foods_ipc_path)
    trimmed = lf.sql(
        """
        SELECT DISTINCT TRIM(LEADING 'vmf' FROM category) as new_category
        FROM self ORDER BY new_category DESC
        """
    ).collect()
    assert trimmed.to_dict(as_series=False) == {
        "new_category": ["seafood", "ruit", "egetables", "eat"]
    }

    # currently unsupported (snowflake-style) trim syntax
    with pytest.raises(SQLSyntaxError, match="unsupported TRIM syntax"):
        lf.sql("SELECT DISTINCT TRIM('*^xxxx^*', '^*') as new_category FROM self")
474
475