Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/sql/test_joins.py
8412 views
1
from __future__ import annotations
2
3
from io import BytesIO
4
from pathlib import Path
5
from typing import Any
6
7
import pytest
8
9
import polars as pl
10
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
11
from polars.testing import assert_frame_equal
12
from tests.unit.sql import assert_sql_matches
13
14
15
@pytest.fixture
def foods_ipc_path() -> Path:
    """Return the location of the ``foods1.ipc`` file, two directories above this module."""
    return Path(__file__).parents[1].joinpath("io", "files", "foods1.ipc")
18
19
20
@pytest.mark.parametrize(
    ("sql", "expected"),
    [
        (
            "SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (a,c)",
            pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}),
        ),
        (
            "SELECT * FROM tbl_a SEMI JOIN tbl_b USING (a,c)",
            pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}),
        ),
        (
            "SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (a)",
            pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}),
        ),
        (
            "SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b USING (a)",
            pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String}),
        ),
        (
            "SELECT * FROM tbl_a ANTI JOIN tbl_b USING (a)",
            pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.String}),
        ),
        (
            "SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)",
            pl.DataFrame({"a": [1, 3], "b": [4, 6], "c": ["w", "z"]}),
        ),
        (
            "SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)",
            pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}),
        ),
        (
            "SELECT * FROM tbl_a RIGHT ANTI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)",
            pl.DataFrame({"a": [2], "b": [5], "c": ["y"]}),
        ),
        (
            "SELECT * FROM tbl_a RIGHT SEMI JOIN tbl_b USING (b) RIGHT SEMI JOIN tbl_c USING (c)",
            pl.DataFrame({"c": ["z"], "d": [25.5]}),
        ),
        (
            "SELECT * FROM tbl_a RIGHT SEMI JOIN tbl_b USING (b) RIGHT ANTI JOIN tbl_c USING (c)",
            pl.DataFrame({"c": ["w", "y"], "d": [10.5, -50.0]}),
        ),
    ],
)
def test_join_anti_semi(sql: str, expected: pl.DataFrame) -> None:
    """Run each SEMI/ANTI join query against three fixed tables and compare frames."""
    tbl_a = pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]})
    tbl_b = pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]})
    tbl_c = pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]})

    # eager context: execute() hands back a DataFrame directly
    ctx = pl.SQLContext({"tbl_a": tbl_a, "tbl_b": tbl_b, "tbl_c": tbl_c}, eager=True)
    actual = ctx.execute(sql)
    assert_frame_equal(expected, actual)
73
74
75
def test_join_cross() -> None:
    """CROSS JOIN of two three-row tables yields all nine row combinations."""
    tables = {
        "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}),
        "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
    }
    query = """
        SELECT *
        FROM tbl_a
        CROSS JOIN tbl_b
        ORDER BY a, b, c
    """
    with pl.SQLContext(tables, eager=True) as ctx:
        result = ctx.execute(query)

    # each tbl_a row is paired with every tbl_b row
    expected_rows = [
        (1, 4, "w", 3, 6, "x"),
        (1, 4, "w", 2, 5, "y"),
        (1, 4, "w", 1, 4, "z"),
        (2, 0, "y", 3, 6, "x"),
        (2, 0, "y", 2, 5, "y"),
        (2, 0, "y", 1, 4, "z"),
        (3, 6, "z", 3, 6, "x"),
        (3, 6, "z", 2, 5, "y"),
        (3, 6, "z", 1, 4, "z"),
    ]
    assert result.rows() == expected_rows
100
101
102
def test_join_cross_11927() -> None:
    """CROSS JOIN followed by a WHERE equality filter keeps only the matching ids."""
    # NOTE: the frame names are referenced from the SQL text, so they must not change
    df1 = pl.DataFrame({"id": [1, 2, 3]})
    df2 = pl.DataFrame({"id": [3, 4, 5]})
    df3 = pl.DataFrame({"id": [4, 5, 6]})

    # df1/df2 share exactly one id value (3)
    overlapping = pl.sql(
        "SELECT df1.id FROM df1 CROSS JOIN df2 WHERE df1.id = df2.id"
    ).collect()
    assert_frame_equal(overlapping, pl.DataFrame({"id": [3]}))

    # df1/df3 share no id values at all
    disjoint = pl.sql(
        "SELECT * FROM df1 CROSS JOIN df3 WHERE df1.id = df3.id"
    ).collect()
    assert disjoint.is_empty()
112
113
114
def test_cross_join_unnest_from_table() -> None:
    """CROSS JOIN UNNEST on a list column produces one output row per list item."""
    df = pl.DataFrame({"id": [1, 2], "items": [[100, 200], [300, 400, 500]]})
    query = """
        SELECT id, item
        FROM self CROSS JOIN UNNEST(items) AS item
        ORDER BY id DESC, item ASC
    """
    expected = {
        "id": [2, 2, 2, 1, 1],
        "item": [300, 400, 500, 100, 200],
    }
    assert_sql_matches(
        frames=df,
        query=query,
        compare_with="duckdb",
        expected=expected,
    )
129
130
131
def test_cross_join_unnest_from_cte() -> None:
    """CROSS JOIN UNNEST also works when the list column comes from a CTE."""
    query = """
        WITH data AS (
          SELECT 'xyz' AS id, [0,1,2] AS items
          UNION ALL
          SELECT 'abc', [3,4]
        )
        SELECT id, item
        FROM data CROSS JOIN UNNEST(items) AS item
        ORDER BY item
    """
    expected = {
        "id": ["xyz", "xyz", "xyz", "abc", "abc"],
        "item": [0, 1, 2, 3, 4],
    }
    # no input frames: all of the data is produced by the CTE itself
    assert_sql_matches(
        {},
        query=query,
        compare_with="duckdb",
        expected=expected,
    )
150
151
152
@pytest.mark.parametrize(
    "join_clause",
    [
        "ON f1.category = f2.category",
        "ON f2.category = f1.category",
        "USING (category)",
    ],
)
def test_join_inner(foods_ipc_path: Path, join_clause: str) -> None:
    """
    INNER JOIN two filtered views of the foods data on ``category``.

    The same result is expected for every spelling of the join constraint
    (``ON`` in either operand order, or ``USING``).
    """
    foods1 = pl.scan_ipc(foods_ipc_path)
    # not dead code: "foods2" is looked up by name from the SQL text below
    foods2 = foods1

    out = pl.sql(
        f"""
        SELECT *
        FROM
          (SELECT * FROM foods1 WHERE fats_g != 0) f1
        INNER JOIN
          (SELECT * FROM foods2 WHERE fats_g = 0) f2
        {join_clause}
        ORDER BY ALL
        LIMIT 2
        """,
        eager=True,
    )
    expected = pl.DataFrame(
        {
            "category": ["fruit", "fruit"],
            "calories": [50, 50],
            "fats_g": [4.5, 4.5],
            "sugars_g": [0, 0],
            "category:f2": ["fruit", "fruit"],
            "calories:f2": [30, 30],
            "fats_g:f2": [0.0, 0.0],
            "sugars_g:f2": [3, 5],
        }
    )
    # only the values are compared; the literal frame's dtypes are not matched
    assert_frame_equal(expected, out, check_dtypes=False)
191
192
193
def test_join_inner_15663() -> None:
    """Aliased same-named columns from both sides of an INNER JOIN ... USING resolve correctly."""
    # NOTE: "df_a"/"df_b" are picked up by name via `register_globals=True`
    df_a = pl.DataFrame({"LOCID": [1, 2, 3], "VALUE": [0.1, 0.2, 0.3]})
    df_b = pl.DataFrame({"LOCID": [1, 2, 3], "VALUE": [25.6, 53.4, 12.7]})
    df_expected = pl.DataFrame(
        {
            "LOCID": [1, 2, 3],
            "VALUE_A": [0.1, 0.2, 0.3],
            "VALUE_B": [25.6, 53.4, 12.7],
        }
    )
    query = """
        SELECT
          a.LOCID,
          a.VALUE AS VALUE_A,
          b.VALUE AS VALUE_B
        FROM df_a AS a INNER JOIN df_b AS b USING (LOCID)
        ORDER BY LOCID
    """
    with pl.SQLContext(register_globals=True, eager=True) as ctx:
        actual = ctx.execute(query)
        assert_frame_equal(df_expected, actual)
214
215
216
@pytest.mark.parametrize(
    ("join_clause", "expected_error"),
    [
        (
            """
            INNER JOIN tbl_b USING (a,b)
            INNER JOIN tbl_c USING (c)
            """,
            None,
        ),
        (
            """
            INNER JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b
            INNER JOIN tbl_c ON tbl_b.c = tbl_c.c
            """,
            None,
        ),
        (
            """
            INNER JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b
            INNER JOIN tbl_c ON tbl_a.c = tbl_c.c --<< (no "c" in 'tbl_a')
            """,
            "no column named 'c' found in table 'tbl_a'",
        ),
    ],
)
def test_join_inner_multi(join_clause: str, expected_error: str | None) -> None:
    """
    Chain two INNER JOINs across three tables.

    When ``expected_error`` is None the query must return the single matching
    row; otherwise the query must raise a ``SQLInterfaceError`` whose message
    contains ``expected_error``.
    """
    frames = {
        "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}),
        "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
        "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}),
    }
    with pl.SQLContext(frames) as ctx:
        assert ctx.tables() == ["tbl_a", "tbl_b", "tbl_c"]
        query = f"""
            SELECT tbl_a.a, tbl_a.b, tbl_b.c, tbl_c.d
            FROM tbl_a {join_clause}
            ORDER BY tbl_a.a DESC
        """
        if expected_error is None:
            out = ctx.execute(query)
            assert out.collect().rows() == [(1, 4, "z", 25.5)]
        else:
            # the error case must actually raise; a silent success is a failure
            with pytest.raises(SQLInterfaceError) as exc_info:
                ctx.execute(query).collect()
            # substring check (not a regex), matching the original comparison
            assert expected_error in str(exc_info.value)
262
263
264
@pytest.mark.parametrize(
    "join_clause",
    [
        """
        LEFT JOIN tbl_b USING (a,b)
        LEFT JOIN tbl_c USING (c)
        """,
        """
        LEFT JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b
        LEFT JOIN tbl_c ON tbl_b.c = tbl_c.c
        """,
    ],
)
def test_join_left_multi(join_clause: str) -> None:
    """Chain two LEFT JOINs; unmatched right-hand values must come back as nulls."""
    tables = {
        "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}),
        "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
        "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}),
    }
    expected_rows = [
        (3, 6, "x", None),
        (2, None, None, None),
        (1, 4, "z", 25.5),
    ]
    # the final column is selected both qualified ("tbl_c.d") and unqualified ("d")
    column_variants = (
        "tbl_a.a, tbl_a.b, tbl_b.c, tbl_c.d",
        "tbl_a.a, tbl_a.b, tbl_b.c, d",
    )
    with pl.SQLContext(tables) as ctx:
        for select_cols in column_variants:
            query = f"SELECT {select_cols} FROM tbl_a {join_clause} ORDER BY a DESC"
            result = ctx.execute(query).collect()
            assert result.rows() == expected_rows
296
297
298
def test_join_left_multi_nested() -> None:
    """LEFT JOIN an aliased subquery (itself a LEFT JOIN) against a third table."""
    tables = {
        "tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}),
        "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
        "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}),
    }
    query = """
        SELECT tbl_x.a, tbl_x.b, tbl_x.c, tbl_c.d FROM (
            SELECT *
            FROM tbl_a
            LEFT JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b
        ) tbl_x
        LEFT JOIN tbl_c ON tbl_x.c = tbl_c.c
        ORDER BY tbl_x.a ASC
    """
    with pl.SQLContext(tables) as ctx:
        lazy_result = ctx.execute(query)
        result = lazy_result.collect()

    expected_rows = [
        (1, 4, "z", 25.5),
        (2, None, None, None),
        (3, 6, "x", None),
    ]
    assert result.rows() == expected_rows
322
323
324
def test_join_misc_13618() -> None:
    """Self-join one frame registered under two names, on differently-named columns."""
    df = pl.DataFrame(
        {
            "A": [1, 2, 3, 4, 5],
            "B": [5, 4, 3, 2, 1],
            "fruits": ["banana", "banana", "apple", "apple", "banana"],
            "cars": ["beetle", "audi", "beetle", "beetle", "beetle"],
        }
    )
    # the same frame is registered as both "t" and "t1"
    res = (
        pl.SQLContext(t=df, t1=df, eager=True)
        .execute(
            """
            SELECT t.A, t.fruits, t1.B, t1.cars
            FROM t
            JOIN t1 ON t.A = t1.B
            ORDER BY t.A DESC
            """
        )
        .to_dict(as_series=False)
    )
    assert res == {
        "A": [5, 4, 3, 2, 1],
        "fruits": ["banana", "apple", "apple", "banana", "banana"],
        "B": [5, 4, 3, 2, 1],
        "cars": ["beetle", "audi", "beetle", "beetle", "beetle"],
    }
353
354
355
def test_join_misc_16255() -> None:
    """Join two CSV-derived frames and alias the identically named ``data`` columns."""
    # NOTE: "df1"/"df2" are referenced by name from the SQL text
    df1 = pl.read_csv(BytesIO(b"id,data\n1,open"))
    df2 = pl.read_csv(BytesIO(b"id,data\n1,closed"))
    query = """
        SELECT a.id, a.data AS d1, b.data AS d2
        FROM df1 AS a JOIN df2 AS b
        ON a.id = b.id
    """
    result = pl.sql(query, eager=True)
    assert result.rows() == [(1, "open", "closed")]
367
368
369
@pytest.mark.parametrize(
    "constraint",
    ["tbl.a != tbl.b", "tbl.a > tbl.b", "a >= b", "a < b", "b <= a"],
)
def test_non_equi_joins(constraint: str) -> None:
    """Every non-equality join constraint is rejected with a SQLInterfaceError."""
    # no support (yet) for non equi-joins in polars joins
    # TODO: integrate awareness of new IEJoin
    tbl = pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})
    query = f"""
        SELECT *
        FROM tbl
        LEFT JOIN tbl ON {constraint} -- not an equi-join
    """
    with (
        pytest.raises(
            SQLInterfaceError,
            match=r"only equi-join constraints \(combined with 'AND'\) are currently supported",
        ),
        pl.SQLContext({"tbl": tbl}) as ctx,
    ):
        ctx.execute(query)
389
390
391
def test_implicit_joins() -> None:
    """Comma-separated (implicit) joins raise an error pointing at explicit JOIN syntax."""
    # no support for this yet; ensure we catch it
    tbl = pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2], "c": ["x", "y", "z"]})
    query = """
        SELECT t1.*
        FROM tbl AS t1, tbl AS t2
        WHERE t1.a = t2.b
    """
    with (
        pytest.raises(
            SQLInterfaceError,
            match=r"not currently supported .* use explicit JOIN syntax instead",
        ),
        pl.SQLContext({"tbl": tbl}) as ctx,
    ):
        ctx.execute(query)
413
414
415
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        # INNER joins
        (
            "SELECT df1.* FROM df1 INNER JOIN df2 USING (a)",
            {"a": [1, 3], "b": ["x", "z"], "c": [100, 300]},
        ),
        (
            "SELECT df2.* FROM df1 INNER JOIN df2 USING (a)",
            {"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]},
        ),
        (
            "SELECT df1.* FROM df2 INNER JOIN df1 USING (a)",
            {"a": [1, 3], "b": ["x", "z"], "c": [100, 300]},
        ),
        (
            "SELECT df2.* FROM df2 INNER JOIN df1 USING (a)",
            {"a": [1, 3], "b": ["qq", "pp"], "c": [400, 500]},
        ),
        # LEFT joins
        (
            "SELECT df1.* FROM df1 LEFT JOIN df2 USING (a)",
            {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]},
        ),
        (
            "SELECT df2.* FROM df1 LEFT JOIN df2 USING (a)",
            {"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]},
        ),
        (
            "SELECT df1.* FROM df2 LEFT JOIN df1 USING (a)",
            {"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]},
        ),
        (
            "SELECT df2.* FROM df2 LEFT JOIN df1 USING (a)",
            {"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]},
        ),
        # RIGHT joins
        (
            "SELECT df1.* FROM df1 RIGHT JOIN df2 USING (a)",
            {"a": [1, 3, None], "b": ["x", "z", None], "c": [100, 300, None]},
        ),
        (
            "SELECT df2.* FROM df1 RIGHT JOIN df2 USING (a)",
            {"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]},
        ),
        (
            "SELECT df1.* FROM df2 RIGHT JOIN df1 USING (a)",
            {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]},
        ),
        (
            "SELECT df2.* FROM df2 RIGHT JOIN df1 USING (a)",
            {"a": [1, 3, None], "b": ["qq", "pp", None], "c": [400, 500, None]},
        ),
        # FULL joins
        (
            "SELECT df1.* FROM df1 FULL JOIN df2 USING (a)",
            {
                "a": [1, 2, 3, None],
                "b": ["x", "y", "z", None],
                "c": [100, 200, 300, None],
            },
        ),
        (
            "SELECT df2.* FROM df1 FULL JOIN df2 USING (a)",
            {
                "a": [1, 3, 4, None],
                "b": ["qq", "pp", "oo", None],
                "c": [400, 500, 600, None],
            },
        ),
        (
            "SELECT df1.* FROM df2 FULL JOIN df1 USING (a)",
            {
                "a": [1, 2, 3, None],
                "b": ["x", "y", "z", None],
                "c": [100, 200, 300, None],
            },
        ),
        (
            "SELECT df2.* FROM df2 FULL JOIN df1 USING (a)",
            {
                "a": [1, 3, 4, None],
                "b": ["qq", "pp", "oo", None],
                "c": [400, 500, 600, None],
            },
        ),
    ],
)
def test_wildcard_resolution_and_join_order(
    query: str, expected: dict[str, Any]
) -> None:
    """A qualified wildcard (``df1.*`` / ``df2.*``) returns that table's columns for every join type and order."""
    # NOTE: "df1"/"df2" are referenced by name from the parametrized SQL
    df1 = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [100, 200, 300]})
    df2 = pl.DataFrame({"a": [1, 3, 4], "b": ["qq", "pp", "oo"], "c": [400, 500, 600]})

    actual = pl.sql(query).collect()
    df_expected = pl.DataFrame(expected)

    # the queries have no ORDER BY, so row order is not compared
    assert_frame_equal(actual, df_expected, check_row_order=False)
516
517
518
def test_natural_joins_01() -> None:
    """Chain NATURAL LEFT/INNER joins over four frames that share ``CharacterID``."""
    # NOTE: "df1"/"df2" are also referenced by name in the `pl.sql` calls below
    df1 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 4],
            "FirstName": ["Jernau Morat", "Cheradenine", "Byr", "Diziet"],
            "LastName": ["Gurgeh", "Zakalwe", "Genar-Hofoen", "Sma"],
        }
    )
    df2 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 5],
            "Role": ["Protagonist", "Protagonist", "Protagonist", "Antagonist"],
            "Book": [
                "Player of Games",
                "Use of Weapons",
                "Excession",
                "Consider Phlebas",
            ],
        }
    )
    df3 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 4],
            "Affiliation": ["Culture", "Culture", "Culture", "Shellworld"],
            "Species": ["Pan-human", "Human", "Human", "Oct"],
        }
    )
    df4 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 6],
            "Ship": [
                "Limiting Factor",
                "Xenophobe",
                "Grey Area",
                "Falling Outside The Normal Moral Constraints",
            ],
            "Drone": ["Flere-Imsaho", "Skaffen-Amtiskaw", "Eccentric", "Psychopath"],
        }
    )
    query = """
        SELECT *
        FROM df1
        NATURAL LEFT JOIN df2
        NATURAL INNER JOIN df3
        NATURAL LEFT JOIN df4
        ORDER BY ALL
    """
    tables = {"df1": df1, "df2": df2, "df3": df3, "df4": df4}
    with pl.SQLContext(tables, eager=True) as ctx:
        result = ctx.execute(query)

    # expected rows, written positionally against the expected output columns
    columns = (
        "CharacterID",
        "FirstName",
        "LastName",
        "Role",
        "Book",
        "Affiliation",
        "Species",
        "Ship",
        "Drone",
    )
    expected_values = [
        (
            1,
            "Jernau Morat",
            "Gurgeh",
            "Protagonist",
            "Player of Games",
            "Culture",
            "Pan-human",
            "Limiting Factor",
            "Flere-Imsaho",
        ),
        (
            2,
            "Cheradenine",
            "Zakalwe",
            "Protagonist",
            "Use of Weapons",
            "Culture",
            "Human",
            "Xenophobe",
            "Skaffen-Amtiskaw",
        ),
        (
            3,
            "Byr",
            "Genar-Hofoen",
            "Protagonist",
            "Excession",
            "Culture",
            "Human",
            "Grey Area",
            "Eccentric",
        ),
        # CharacterID 4 is absent from df2 and df4, so those columns are null
        (4, "Diziet", "Sma", None, None, "Shellworld", "Oct", None, None),
    ]
    expected_rows = [dict(zip(columns, values)) for values in expected_values]
    assert result.rows(named=True) == expected_rows

    # misc errors
    with pytest.raises(SQLSyntaxError, match=r"did you mean COLUMNS\(\*\)\?"):
        pl.sql("SELECT * FROM df1 NATURAL JOIN df2 WHERE COLUMNS('*') >= 5")

    with pytest.raises(SQLSyntaxError, match=r"COLUMNS expects a regex"):
        pl.sql("SELECT COLUMNS(1234) FROM df1 NATURAL JOIN df2")
623
624
625
@pytest.mark.parametrize(
    ("cols_constraint", "expect_data"),
    [
        (">= 5", [(8, 8, 6)]),
        ("< 7", [(5, 4, 4)]),
        ("< 8", [(5, 4, 4), (7, 4, 4), (0, 7, 2)]),
        ("!= 4", [(8, 8, 6), (2, 8, 6), (0, 7, 2)]),
    ],
)
def test_natural_joins_02(
    cols_constraint: str, expect_data: list[tuple[int, ...]]
) -> None:
    """
    NATURAL JOIN filtered by a ``COLUMNS(*)`` constraint.

    ``expect_data`` holds the rows expected to survive the filter, given
    positionally in the result's column order.
    """
    # NOTE: "df1"/"df2" are referenced by name from the SQL text
    df1 = pl.DataFrame(
        {
            "x": [1, 5, 3, 8, 6, 7, 4, 0, 2],
            "y": [3, 4, 6, 8, 3, 4, 1, 7, 8],
        }
    )
    df2 = pl.DataFrame(
        {
            "y": [0, 4, 0, 8, 0, 4, 0, 7, None],
            "z": [9, 8, 7, 6, 5, 4, 3, 2, 1],
        },
    )
    actual = pl.sql(
        f"""
        SELECT *
        FROM df1 NATURAL JOIN df2
        WHERE COLUMNS(*) {cols_constraint}
        """
    ).collect()

    # column names are taken from the result, so only the row values are checked
    df_expected = pl.DataFrame(expect_data, schema=actual.columns, orient="row")
    assert_frame_equal(actual, df_expected, check_row_order=False)
657
658
659
@pytest.mark.parametrize(
    "join_clause",
    [
        """
        df2 JOIN df3 ON
        df2.CharacterID = df3.CharacterID
        """,
        """
        df2 INNER JOIN (
          df3 JOIN df4 ON df3.CharacterID = df4.CharacterID
        ) AS r0 ON df2.CharacterID = df3.CharacterID
        """,
    ],
)
def test_nested_join(join_clause: str) -> None:
    """INNER JOIN against a parenthesised (one- or two-level) nested join relation."""
    df1 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 4],
            "FirstName": ["Jernau Morat", "Cheradenine", "Byr", "Diziet"],
            "LastName": ["Gurgeh", "Zakalwe", "Genar-Hofoen", "Sma"],
        }
    )
    df2 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 5],
            "Role": ["Protagonist", "Protagonist", "Protagonist", "Antagonist"],
            "Book": [
                "Player of Games",
                "Use of Weapons",
                "Excession",
                "Consider Phlebas",
            ],
        }
    )
    df3 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 5, 6],
            "Affiliation": ["Culture", "Culture", "Culture", "Shellworld"],
            "Species": ["Pan-human", "Human", "Human", "Oct"],
        }
    )
    df4 = pl.DataFrame(
        {
            "CharacterID": [1, 2, 3, 6],
            "Ship": [
                "Limiting Factor",
                "Xenophobe",
                "Grey Area",
                "Falling Outside The Normal Moral Constraints",
            ],
            "Drone": ["Flere-Imsaho", "Skaffen-Amtiskaw", "Eccentric", "Psychopath"],
        }
    )
    query = f"""
        SELECT df1.CharacterID, df1.FirstName, df2.Role, df3.Species
        FROM df1
        INNER JOIN ({join_clause}) AS r99
        ON df1.CharacterID = df2.CharacterID
        ORDER BY ALL
    """
    tables = {"df1": df1, "df2": df2, "df3": df3, "df4": df4}
    with pl.SQLContext(tables, eager=True) as ctx:
        result = ctx.execute(query)

    # only CharacterIDs 1 and 2 are present in all of the joined frames
    columns = ("CharacterID", "FirstName", "Role", "Species")
    expected_values = [
        (1, "Jernau Morat", "Protagonist", "Pan-human"),
        (2, "Cheradenine", "Protagonist", "Human"),
    ]
    expected_rows = [dict(zip(columns, values)) for values in expected_values]
    assert result.rows(named=True) == expected_rows
739
740
741
def test_miscellaneous_cte_join_aliasing() -> None:
    """CROSS JOIN a two-row CTE with itself; all four value pairs are expected."""
    query = """
        WITH t AS (SELECT a FROM (VALUES(1),(2)) tbl(a))
        SELECT * FROM t CROSS JOIN t
    """
    ctx = pl.SQLContext()
    result = ctx.execute(query, eager=True)

    # no ORDER BY in the query, so the rows are sorted before comparing
    expected_rows = [
        (1, 1),
        (1, 2),
        (2, 1),
        (2, 2),
    ]
    assert sorted(result.rows()) == expected_rows
756
757
758
def test_nested_joins_17381() -> None:
    """An unused CTE combined with a self-joining IN-subquery still returns every id."""
    frame = pl.DataFrame({"id": ["one", "two"]})
    query = """
        -- the interaction of the (unused) CTE and the nested subquery resulted
        -- in arena mutation/cleanup that wasn't accounted for, affecting state
        WITH c AS (SELECT a.id FROM a)
        SELECT *
        FROM a
        WHERE id IN (
          SELECT a2.id
          FROM a
          INNER JOIN a AS a2 ON a.id = a2.id
        )
    """
    ctx = pl.SQLContext({"a": frame})
    result = ctx.execute(query, eager=True)

    # compared as a set: row order is not part of the check
    assert set(result["id"]) == {"one", "two"}
778
779
780
def test_unnamed_nested_join_relation() -> None:
    """Joining against a parenthesised join that has no alias raises an error."""
    frame = pl.DataFrame({"a": 1})
    query = """
        SELECT *
        FROM left
        JOIN (right JOIN right ON right.a = right.a)
        ON left.a = right.a
    """
    with (
        pl.SQLContext({"left": frame, "right": frame}) as ctx,
        pytest.raises(SQLInterfaceError, match="cannot JOIN on unnamed relation"),
    ):
        ctx.execute(query)
795
796
797
def test_nulls_equal_19624() -> None:
    """With ``nulls_equal=False`` null keys never match, in either join direction."""
    df1 = pl.DataFrame({"a": [1, 2, None, None]})
    df2 = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})

    inner_expected = {"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]}

    # (left frame, right frame, join type, validation mode, expected data)
    cases = [
        # left join
        (
            df1,
            df2,
            "left",
            "1:m",
            {"a": [1, 1, 2, 2, None, None], "b": [0, 1, 2, 3, None, None]},
        ),
        (df2, df1, "left", "m:1", {"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]}),
        # inner join
        (df1, df2, "inner", "1:m", inner_expected),
        (df2, df1, "inner", "m:1", inner_expected),
    ]
    for lhs, rhs, how, validate, expected in cases:
        joined = lhs.join(rhs, how=how, on="a", nulls_equal=False, validate=validate)
        assert_frame_equal(joined, pl.DataFrame(expected))
818
819
820
def test_join_on_literal_string_comparison() -> None:
    """A join constraint may combine an equi-join with a string-literal comparison."""
    # NOTE: "df1"/"df2" are referenced by name from the SQL text
    df1 = pl.DataFrame(
        {
            "name": ["alice", "bob", "adam", "charlie"],
            "role": ["admin", "user", "admin", "user"],
        }
    )
    df2 = pl.DataFrame(
        {
            "name": ["alice", "bob", "charlie", "adam"],
            "dept": ["IT", "HR", "IT", "SEC"],
        }
    )
    actual = pl.sql(
        """
        SELECT df1.name, df1.role, df2.dept
        FROM df1
        INNER JOIN df2 ON df1.name = df2.name AND df1.role = 'admin'
        ORDER BY df1.name
        """,
        eager=True,
    )

    # only the two 'admin' rows are expected, ordered by name
    df_expected = pl.DataFrame(
        data=[("adam", "admin", "SEC"), ("alice", "admin", "IT")],
        schema={"name": str, "role": str, "dept": str},
        orient="row",
    )
    assert_frame_equal(actual, df_expected)
846
847
848
@pytest.mark.parametrize(
    ("expression", "expected_length"),
    [
        # case conversion
        ("LOWER(df1.text) = df2.text", 2),
        # first two characters match
        ("SUBSTR(df1.code, 1, 2) = SUBSTR(df2.code, 1, 2)", 3),
        # cartesian on matching lengths
        ("LENGTH(df1.text) = LENGTH(df2.text)", 5),
    ],
)
def test_join_on_expression_conditions(expression: str, expected_length: int) -> None:
    """Join on function-call expressions; only the resulting row count is checked."""
    # NOTE: "df1"/"df2" are referenced by name from the SQL text
    df1 = pl.DataFrame(
        {
            "text": ["HELLO", "WORLD", "FOO"],
            "code": ["ABC", "DEF", "GHI"],
        }
    )
    df2 = pl.DataFrame(
        {
            "text": ["hello", "world", "bar"],
            "code": ["ABX", "DEY", "GHZ"],
        }
    )
    result = pl.sql(
        f"""
        SELECT df1.text AS text1, df2.text AS text2
        FROM df1
        INNER JOIN df2 ON {expression}
        ORDER BY text1
        """,
        eager=True,
    )
    assert len(result) == expected_length
877
878
879
@pytest.mark.parametrize(
    ("df1", "df2", "join_constraint", "select_cols", "expected", "schema"),
    [
        (
            pl.DataFrame(
                {
                    "category": ["fruit", "fruit", "vegetable"],
                    "name": ["apple", "banana", "carrot"],
                    "code": [1, 2, 3],
                }
            ),
            pl.DataFrame(
                {
                    "category": ["fruit", "fruit", "vegetable"],
                    "type": ["sweet", "tropical", "root"],
                    "code_doubled": [2, 4, 6],
                }
            ),
            "df1.category = df2.category AND (df1.code * 2) = df2.code_doubled",
            "df1.name, df1.code, df2.type",
            [("apple", 1, "sweet"), ("banana", 2, "tropical"), ("carrot", 3, "root")],
            ["name", "code", "type"],
        ),
        (
            pl.DataFrame({"id": [1, 2, 3], "name": ["ALICE", "BOB", "CHARLIE"]}),
            pl.DataFrame({"id": [1, 2, 3], "match": ["alice", "bob", "charlie"]}),
            "df1.id = df2.id AND LOWER(df1.name) = df2.match",
            "df1.id, df1.name, df2.match",
            [(1, "ALICE", "alice"), (2, "BOB", "bob"), (3, "CHARLIE", "charlie")],
            ["id", "name", "match"],
        ),
        (
            pl.DataFrame({"x": [2, 4, 6], "y": [1, 2, 3]}),
            pl.DataFrame({"a": [4, 8, 12], "b": [1, 2, 3]}),
            "df1.x * 2 = df2.a AND df1.y = df2.b",
            "df1.x, df1.y, df2.a",
            [(2, 1, 4), (4, 2, 8), (6, 3, 12)],
            ["x", "y", "a"],
        ),
    ],
)
def test_join_on_mixed_expression_conditions(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    join_constraint: str,
    select_cols: str,
    expected: list[tuple[Any, ...]],
    schema: list[str],
) -> None:
    """Join constraints that mix plain column equality with computed expressions."""
    # the "df1"/"df2" parameters are referenced by name from the SQL text
    actual = pl.sql(
        f"""
        SELECT {select_cols}
        FROM df1
        INNER JOIN df2 ON {join_constraint}
        ORDER BY ALL
        """,
        eager=True,
    )
    df_expected = pl.DataFrame(expected, schema=schema, orient="row")
    assert_frame_equal(actual, df_expected)
937
938
939
@pytest.mark.parametrize(
    ("df1", "df2", "join_constraint", "expected"),
    [
        (
            pl.DataFrame({"text": [" Hello ", " World ", " Test "]}),
            pl.DataFrame({"text": ["hello", "world", "other"]}),
            "LOWER(TRIM(df1.text)) = df2.text",
            [(" Hello ", "hello"), (" World ", "world")],
        ),
        (
            pl.DataFrame({"code": ["PREFIX_A", "SECOND_B", "OTHERS_C"]}),
            pl.DataFrame({"code": ["prefix", "second", "others"]}),
            "LOWER(SUBSTR(df1.code,1,6)) = df2.code",
            [("OTHERS_C", "others"), ("PREFIX_A", "prefix"), ("SECOND_B", "second")],
        ),
        (
            pl.DataFrame({"name": ["abc", "abcde", "x"]}),
            pl.DataFrame({"len": [3, 5, 1]}),
            "LENGTH(df1.name) = df2.len",
            [("x", 1), ("abc", 3), ("abcde", 5)],
        ),
    ],
)
def test_join_on_nested_function_expressions(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    join_constraint: str,
    expected: list[tuple[Any, ...]],
) -> None:
    """Join on (nested) function calls applied to the left-hand column."""
    # each case selects the first column of either frame, aliased col1/col2
    left_col = df1.columns[0]
    right_col = df2.columns[0]

    actual = pl.sql(
        f"""
        SELECT df1.{left_col} AS col1, df2.{right_col} AS col2
        FROM df1
        INNER JOIN df2 ON {join_constraint}
        ORDER BY df2.{right_col}
        """,
        eager=True,
    )
    df_expected = pl.DataFrame(expected, schema=["col1", "col2"], orient="row")
    assert_frame_equal(actual, df_expected)
980
981
982
@pytest.mark.parametrize(
    ("df1", "df2", "join_constraint", "select_cols", "expected", "schema"),
    [
        (
            pl.DataFrame(
                {"id": [1, 2, 3], "category": ["A", "B", "A"], "multiplier": [2, 3, 4]}
            ),
            pl.DataFrame(
                {"id": [1, 2, 3], "base": [5, 15, 20], "category": ["A", "B", "C"]}
            ),
            "df1.id = df2.id AND df1.multiplier * 5 = df2.base AND df1.category = 'A'",
            "df1.id, df1.multiplier, df2.base",
            [(3, 4, 20)],
            ["id", "multiplier", "base"],
        ),
        (
            pl.DataFrame({"id": [1, 2, 3], "value": [10, 20, 30]}),
            pl.DataFrame({"id": [1, 2, 3], "target": [20, 40, 60]}),
            "df1.id = df2.id AND (df1.value * 2) = df2.target AND df1.id = 2",
            "df1.id, df1.value, df2.target",
            [(2, 20, 40)],
            ["id", "value", "target"],
        ),
        (
            pl.DataFrame(
                {
                    "x": [1, 2, 3],
                    "type": ["A", "B", "A"],
                    "status": ["active", "inactive", "active"],
                }
            ),
            pl.DataFrame({"x": [1, 2, 3], "data": ["foo", "bar", "baz"]}),
            "df1.x = df2.x AND df1.type = 'A' AND df1.status = 'active'",
            "df1.x, df2.data",
            [(1, "foo"), (3, "baz")],
            ["x", "data"],
        ),
    ],
)
def test_join_on_expression_with_literals(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    join_constraint: str,
    select_cols: str,
    expected: list[tuple[Any, ...]],
    schema: list[str],
) -> None:
    """Join constraints that combine equi-joins, expressions and literal comparisons."""
    # the "df1"/"df2" parameters are referenced by name from the SQL text
    actual = pl.sql(
        f"""
        SELECT {select_cols}
        FROM df1
        INNER JOIN df2 ON {join_constraint}
        ORDER BY ALL
        """,
        eager=True,
    )
    df_expected = pl.DataFrame(expected, schema=schema, orient="row")
    assert_frame_equal(actual, df_expected)
1042
1043
1044
@pytest.mark.parametrize(
    ("df1", "df2", "join_constraint", "reversed_join_constraint", "expected", "schema"),
    [
        (
            pl.DataFrame({"id": [1, 2, 3], "val": ["a", "b", "c"]}),
            pl.DataFrame({"id": [2, 3, 4], "val": ["x", "y", "z"]}),
            "df1.id = df2.id",
            "df2.id = df1.id",
            [(2, "b", "x"), (3, "c", "y")],
            ["id", "val1", "val2"],
        ),
        (
            pl.DataFrame({"x": [1, 2, 3]}),
            pl.DataFrame({"y": [2, 4, 6]}),
            "df1.x * 2 = df2.y",
            "df2.y = (df1.x * 2)",
            [(1, 2), (2, 4), (3, 6)],
            ["x", "y"],
        ),
        (
            pl.DataFrame({"a": [5, 10, 15]}),
            pl.DataFrame({"b": [10, 20, 30]}),
            "(df1.a + df1.a) = df2.b",
            "df2.b = (df1.a + df1.a)",
            [(5, 10), (10, 20), (15, 30)],
            ["a", "b"],
        ),
    ],
)
def test_join_on_reversed_constraint_order(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    join_constraint: str,
    reversed_join_constraint: str,
    expected: list[tuple[Any, ...]],
    schema: list[str],
) -> None:
    """The operand order of a join constraint must not change the result."""
    if len(schema) == 3:
        # the three-column case needs explicit aliases for the two "val" columns
        select_cols = "df1.id, df1.val AS val1, df2.val AS val2"
    else:
        # otherwise the n-th schema column is selected from the n-th frame
        select_cols = ", ".join(
            f"df{idx}.{col}" for idx, col in enumerate(schema, start=1)
        )

    df_expected = pl.DataFrame(expected, schema=schema, orient="row")

    for constraint in (join_constraint, reversed_join_constraint):
        actual = pl.sql(
            query=f"""
                SELECT {select_cols}
                FROM df1
                INNER JOIN df2 ON {constraint}
                ORDER BY ALL
            """,
            eager=True,
        )
        assert_frame_equal(actual, df_expected)
1102
1103
1104
@pytest.mark.parametrize(
    ("df1", "df2", "join_constraint", "expected", "schema"),
    [
        (
            pl.DataFrame({"a": [1, 2, 3]}),
            pl.DataFrame({"b": [2, 4, 6]}),
            "a * 2 = b",
            [(1, 2), (2, 4), (3, 6)],
            ["a", "b"],
        ),
        (
            pl.DataFrame({"x": [5, 10, 15], "y": [3, 5, 7]}),
            pl.DataFrame({"sum": [8, 15, 22]}),
            "x + y = sum",
            [(5, 3, 8), (10, 5, 15), (15, 7, 22)],
            ["x", "y", "sum"],
        ),
        (
            pl.DataFrame({"name": ["abc", "hello", "test"]}),
            pl.DataFrame({"len": [3, 5, 4]}),
            "LENGTH(name) = len",
            [("abc", 3), ("hello", 5), ("test", 4)],
            ["name", "len"],
        ),
    ],
)
def test_join_on_unqualified_expressions(
    df1: pl.DataFrame,
    df2: pl.DataFrame,
    join_constraint: str,
    expected: list[tuple[Any, ...]],
    schema: list[str],
) -> None:
    """Join constraints may use column names without a table qualifier."""
    # select every column of both frames, qualified, df1's columns first
    qualified_cols = [f"df1.{col}" for col in df1.columns]
    qualified_cols += [f"df2.{col}" for col in df2.columns]
    select_list = ", ".join(qualified_cols)

    actual = pl.sql(
        f"""
        SELECT {select_list}
        FROM df1
        INNER JOIN df2 ON {join_constraint}
        ORDER BY ALL
        """,
        eager=True,
    )
    df_expected = pl.DataFrame(expected, schema=schema, orient="row")
    assert_frame_equal(actual, df_expected)
1153
1154
1155
def test_multiway_join_chain_with_aliased_cols() -> None:
    """Three-way join where column "a" exists in every frame and gets suffixed."""
    # tracking/resolving constraints for 3-way (or more) joins can be... "fun" :)
    # ref: https://github.com/pola-rs/polars/issues/25126

    # NOTE: the frame names are referenced from the SQL text
    df1 = pl.DataFrame({"a": [111, 222], "x1": ["df1", "df1"]})
    df2 = pl.DataFrame({"a": [333, 111], "b": [444, 222], "x2": ["df2", "df2"]})
    df3 = pl.DataFrame({"a": [222, 111], "x3": ["df3", "df3"]})

    # both queries are expected to produce the same output columns
    expected_cols = ["a", "x3", "a:df2", "b", "x2", "a:df1", "x1"]

    cases = [
        (
            # three-way join where "a" exists in all three frames (df1, df2, df3)
            """
            SELECT * FROM df3
            INNER JOIN df2 ON df2.b = df3.a
            INNER JOIN df1 ON df1.a = df2.a
            """,
            (222, "df3", 111, 222, "df2", 111, "df1"),
        ),
        (
            # almost the same, but the final constraint joins df1 on df3's "a"
            """
            SELECT * FROM df3
            INNER JOIN df2 ON df2.b = df3.a
            INNER JOIN df1 ON df1.a = df3.a
            """,
            (222, "df3", 111, 222, "df2", 222, "df1"),
        ),
    ]
    for query, expected_row in cases:
        result = pl.sql(query, eager=True)

        assert result.height == 1
        assert result.columns == expected_cols
        assert result.row(0) == expected_row
1190
1191
1192
@pytest.mark.parametrize(
    ("join_condition", "expected_error"),
    [
        # left-hand side of the constraint mixes columns from both frames
        (
            "(df1.id + df2.val) = df2.id",
            r"unsupported join condition: left side references both 'df1' and 'df2'",
        ),
        # right-hand side of the constraint mixes columns from both frames
        (
            "df1.id = (df2.id + df1.val)",
            r"unsupported join condition: right side references both 'df1' and 'df2'",
        ),
    ],
)
def test_unsupported_join_conditions(
    join_condition: str,
    expected_error: str,
) -> None:
    """Expect an error for join constraints with a side spanning both frames.

    Note: such constraints are technically valid (if unusual) SQL, but they
    are not supported.
    """
    # the query refers to these frames by variable name (do not rename them)
    df1 = pl.DataFrame({"id": [1, 2, 3], "val": [10, 20, 30]})
    df2 = pl.DataFrame({"id": [2, 3, 4], "val": [20, 30, 40]})

    unsupported_query = f"SELECT * FROM df1 INNER JOIN df2 ON {join_condition}"

    with pytest.raises(SQLInterfaceError, match=expected_error):
        pl.sql(unsupported_query)
1212
1213
1214
def test_ambiguous_column_detection_in_joins() -> None:
    """Expect an error when an unqualified column exists in several tables.

    The error message should include a helpful suggestion listing the
    qualified references that could be used instead.
    """
    # both CTEs expose a column "k", and the SELECT does not qualify it
    ambiguous_query = """
        WITH
          a AS (SELECT 0 AS k),
          c AS (SELECT 0 AS k)
        SELECT k FROM a JOIN c ON a.k = c.k
    """
    error_pattern = r'ambiguous reference to column "k" \(use one of: a\.k, c\.k\)'

    with pytest.raises(SQLInterfaceError, match=error_pattern):
        pl.sql(query=ambiguous_query, eager=True)
1230
1231
1232
def test_duplicate_column_detection_via_wildcard() -> None:
    """Expect an error when a column is selected explicitly and via `a.*`.

    The explicitly selected column is already covered by the qualified
    wildcard from the same table, so a duplicate column error should be raised.
    """
    # the query refers to these frames by variable name (do not rename them)
    a = pl.DataFrame({"id": [1, 2], "x": [10, 20]})
    b = pl.DataFrame({"id": [1, 2], "y": [30, 40]})

    duplicated_select = "SELECT a.*, a.id FROM a JOIN b ON a.id = b.id"
    error_pattern = r"column 'id' is duplicated in the SELECT"

    with pytest.raises(SQLInterfaceError, match=error_pattern):
        pl.sql(duplicated_select, eager=True)
1243
1244
1245
def test_qualified_wildcard_multiway_join() -> None:
    """Select `tbl.*` from each of three joined frames sharing an "id" column."""
    # the query refers to these frames by variable name (do not rename them)
    df1 = pl.DataFrame({"id": [1, 2], "a": ["x", "y"]})
    df2 = pl.DataFrame({"id": [1, 2], "b": ["p", "q"]})
    df3 = pl.DataFrame({"id": [1, 2], "c": ["m", "n"]})

    query = """
        SELECT df1.*, df2.*, df3.*
        FROM df1
        INNER JOIN df2 ON df1.id = df2.id
        INNER JOIN df3 ON df1.id = df3.id
        ORDER BY id
    """
    actual = pl.sql(query).collect()

    # "id" from the second and third frames carries the frame name as a suffix
    assert_frame_equal(
        actual,
        pl.DataFrame(
            {
                "id": [1, 2],
                "a": ["x", "y"],
                "id:df2": [1, 2],
                "b": ["p", "q"],
                "id:df3": [1, 2],
                "c": ["m", "n"],
            }
        ),
    )
1268
1269
1270
def test_qualified_wildcard_self_join() -> None:
    """Select `child.*, parent.*` from a frame left-joined to itself."""
    # the query refers to this frame by variable name (do not rename it)
    df = pl.DataFrame(
        {
            "id": [1, 2, 3],
            "parent": [None, 1, 1],
            "name": ["root", "child1", "child2"],
        }
    )

    query = """
        SELECT child.*, parent.*
        FROM df AS child
        LEFT JOIN df AS parent ON child.parent = parent.id
        ORDER BY id
    """
    actual = pl.sql(query).collect()

    # columns from the right-hand side are suffixed with the "parent" alias;
    # the root row has no parent, so its right-hand columns are all null
    expected_data = {
        "id": [1, 2, 3],
        "parent": [None, 1, 1],
        "name": ["root", "child1", "child2"],
        "id:parent": [None, 1, 1],
        "parent:parent": [None, None, None],
        "name:parent": [None, "root", "root"],
    }
    # "parent:parent" holds only nulls, so its dtype is given explicitly
    expected_frame = pl.DataFrame(
        expected_data,
        schema_overrides={"parent:parent": pl.Int64},
    )
    assert_frame_equal(actual, expected_frame)
1297
1298
1299
@pytest.mark.parametrize(
    ("join_type", "result"),
    [
        (
            "INNER",
            {
                "k": [1],
                "v": ["a"],
                "k:df2": [1],
                "v:df2": ["x"],
            },
        ),
        (
            "LEFT",
            {
                "k": [1, 2],
                "v": ["a", "b"],
                "k:df2": [1, None],
                "v:df2": ["x", None],
            },
        ),
        (
            "RIGHT",
            {
                "k": [1, None],
                "v": ["a", None],
                "k:df2": [1, 3],
                "v:df2": ["x", "y"],
            },
        ),
    ],
)
def test_qualified_wildcard_join_types(
    join_type: str,
    result: dict[str, Any],
) -> None:
    """Select `df1.*, df2.*` across INNER, LEFT and RIGHT joins.

    Both frames have the same column names ("k" and "v"); the expected data
    has the df2 columns suffixed with ":df2". Row order is not checked.
    """
    # the query refers to these frames by variable name (do not rename them)
    df1 = pl.DataFrame({"k": [1, 2], "v": ["a", "b"]})
    df2 = pl.DataFrame({"k": [1, 3], "v": ["x", "y"]})

    wildcard_query = f"""
        SELECT df1.*, df2.*
        FROM df1 {join_type} JOIN df2 ON df1.k = df2.k
    """
    joined = pl.sql(query=wildcard_query, eager=True)

    assert_frame_equal(
        left=pl.DataFrame(result),
        right=joined,
        check_row_order=False,
    )
1333
1334
1335
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        # specific column conflicts with wildcard
        (
            "SELECT a.id, b.* FROM a JOIN b ON a.id = b.id",
            {"id": [1, 2], "id:b": [1, 2], "y": [30, 40]},
        ),
        # specific column doesn't conflict with wildcard
        (
            "SELECT b.y, a.* FROM a JOIN b ON a.id = b.id",
            {"y": [30, 40], "id": [1, 2], "x": [10, 20]},
        ),
        # single-table wildcard (no conflict, uses original names)
        (
            "SELECT b.* FROM a JOIN b ON a.id = b.id",
            {"id": [1, 2], "y": [30, 40]},
        ),
        # table aliases (disambiguation should use the alias)
        (
            "SELECT t1.*, t2.* FROM a AS t1 JOIN b AS t2 ON t1.id = t2.id",
            {"id": [1, 2], "x": [10, 20], "id:t2": [1, 2], "y": [30, 40]},
        ),
        # no column overlap (expect no disambiguation)
        (
            "SELECT a.*, c.* FROM a JOIN c ON a.id = c.k",
            {"id": [1, 2], "x": [10, 20], "k": [1, 2], "z": [50, 60]},
        ),
        # reverse wildcard order (disambiguation follows *table* order)
        (
            "SELECT b.*, a.* FROM a JOIN b ON a.id = b.id",
            {"id:b": [1, 2], "y": [30, 40], "id": [1, 2], "x": [10, 20]},
        ),
    ],
)
def test_qualified_wildcard_combinations(
    query: str,
    expected: dict[str, Any],
) -> None:
    """Check output columns for various mixes of qualified wildcards and columns.

    Each case runs `query` against the frames below and compares the collected
    result with `expected`, ignoring row order.
    """
    # the queries refer to these frames by variable name (do not rename them)
    a = pl.DataFrame({"id": [1, 2], "x": [10, 20]})
    b = pl.DataFrame({"id": [1, 2], "y": [30, 40]})
    c = pl.DataFrame({"k": [1, 2], "z": [50, 60]})

    actual = pl.sql(query).collect()

    assert_frame_equal(
        left=pl.DataFrame(expected),
        right=actual,
        check_row_order=False,
    )
1374
1375