# GitHub repository: pola-rs/polars
# Path: py-polars/tests/unit/io/test_multiscan.py

from __future__ import annotations

import io
from functools import partial
from typing import IO, TYPE_CHECKING, Any, Callable

import pyarrow.parquet as pq
import pytest
from hypothesis import given
from hypothesis import strategies as st

import polars as pl
from polars.meta.index_type import get_index_type
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from pathlib import Path

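# Pair each scan function with the matching DataFrame writer so tests can be
# parametrized over every multiscan-capable format (IPC, Parquet, CSV, NDJSON).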
SCAN_AND_WRITE_FUNCS = [
    (pl.scan_ipc, pl.DataFrame.write_ipc),
    (pl.scan_parquet, pl.DataFrame.write_parquet),
    (pl.scan_csv, pl.DataFrame.write_csv),
    (pl.scan_ndjson, pl.DataFrame.write_ndjson),
]


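# `include_file_paths` should add a column recording the source file of every
# row, in scan order.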
@pytest.mark.write_disk
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_include_file_paths(tmp_path: Path, scan: Any, write: Any) -> None:
    a_path = tmp_path / "a"
    b_path = tmp_path / "b"

    write(pl.DataFrame({"a": [5, 10]}), a_path)
    write(pl.DataFrame({"a": [1996]}), b_path)

    out = scan([a_path, b_path], include_file_paths="f")

    assert_frame_equal(
        out.collect(),
        pl.DataFrame(
            {
                "a": [5, 10, 1996],
                "f": [str(a_path), str(a_path), str(b_path)],
            }
        ),
    )


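# Selecting columns before collecting must give the same result as selecting
# after, for every combination of optional output columns (missing column, row
# index, file path, hive column, data column) and for either projection order.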
@pytest.mark.parametrize(
    ("scan", "write", "ext", "supports_missing_columns", "supports_hive_partitioning"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc", False, True),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet", True, True),
        (pl.scan_csv, pl.DataFrame.write_csv, "csv", False, False),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl", False, False),
    ],
)
@pytest.mark.parametrize("missing_column", [False, True])
@pytest.mark.parametrize("row_index", [False, True])
@pytest.mark.parametrize("include_file_paths", [False, True])
@pytest.mark.parametrize("hive", [False, True])
@pytest.mark.parametrize("col", [False, True])
@pytest.mark.write_disk
def test_multiscan_projection(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
    supports_missing_columns: bool,
    supports_hive_partitioning: bool,
    missing_column: bool,
    row_index: bool,
    include_file_paths: bool,
    hive: bool,
    col: bool,
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [13, 37]})

    if missing_column and supports_missing_columns:
        a = a.with_columns(missing=pl.Series([420, 2000, 9]))

    a_path: Path
    b_path: Path
    multiscan_path: Path

    if hive and supports_hive_partitioning:
        (tmp_path / "hive_col=0").mkdir()
        a_path = tmp_path / "hive_col=0" / f"a.{ext}"
        (tmp_path / "hive_col=1").mkdir()
        b_path = tmp_path / "hive_col=1" / f"b.{ext}"

        multiscan_path = tmp_path

    else:
        a_path = tmp_path / f"a.{ext}"
        b_path = tmp_path / f"b.{ext}"

        multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    base_projection = []
    if missing_column and supports_missing_columns:
        base_projection += ["missing"]
    if row_index:
        base_projection += ["row_index"]
    if include_file_paths:
        base_projection += ["file_path"]
    if hive and supports_hive_partitioning:
        base_projection += ["hive_col"]
    if col:
        base_projection += ["col"]

    ifp = "file_path" if include_file_paths else None
    ri = "row_index" if row_index else None

    args = {
        "missing_columns": "insert" if missing_column else "raise",
        "include_file_paths": ifp,
        "row_index_name": ri,
        "hive_partitioning": hive,
    }

    if not supports_missing_columns:
        del args["missing_columns"]
    if not supports_hive_partitioning:
        del args["hive_partitioning"]

    for projection in [
        base_projection,
        base_projection[::-1],
    ]:
        assert_frame_equal(
            scan(multiscan_path, **args).collect(engine="streaming").select(projection),
            scan(multiscan_path, **args).select(projection).collect(engine="streaming"),
        )

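    # Dropping any single column from the projection must also commute with
    # collection.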
    for remove in range(len(base_projection)):
        new_projection = base_projection.copy()
        new_projection.pop(remove)

        for projection in [
            new_projection,
            new_projection[::-1],
        ]:
            print(projection)
            assert_frame_equal(
                scan(multiscan_path, **args)
                .collect(engine="streaming")
                .select(projection),
                scan(multiscan_path, **args)
                .select(projection)
                .collect(engine="streaming"),
            )


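# Predicates over hive columns and file columns, pushed into a multiscan, must
# select exactly the rows that filtering the fully collected frame selects.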
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
    ],
)
@pytest.mark.write_disk
def test_multiscan_hive_predicate(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [13, 37]})
    c = pl.DataFrame({"col": [3, 5, 2024]})

    (tmp_path / "hive_col=0").mkdir()
    a_path = tmp_path / "hive_col=0" / f"0.{ext}"
    (tmp_path / "hive_col=1").mkdir()
    b_path = tmp_path / "hive_col=1" / f"0.{ext}"
    (tmp_path / "hive_col=2").mkdir()
    c_path = tmp_path / "hive_col=2" / f"0.{ext}"

    multiscan_path = tmp_path

    write(a, a_path)
    write(b, b_path)
    write(c, c_path)

    full = scan(multiscan_path).collect(engine="streaming")
    full_ri = full.with_row_index("ri", 42)

    last_pred = None
    try:
        for pred in [
            pl.col.hive_col == 0,
            pl.col.hive_col == 1,
            pl.col.hive_col == 2,
            pl.col.hive_col < 2,
            pl.col.hive_col > 0,
            pl.col.hive_col != 1,
            pl.col.hive_col != 3,
            pl.col.col == 13,
            pl.col.col != 13,
            (pl.col.col != 13) & (pl.col.hive_col == 1),
            (pl.col.col != 13) & (pl.col.hive_col != 1),
        ]:
            last_pred = pred
            assert_frame_equal(
                full.filter(pred),
                scan(multiscan_path).filter(pred).collect(engine="streaming"),
            )

            assert_frame_equal(
                full_ri.filter(pred),
                scan(multiscan_path)
                .with_row_index("ri", 42)
                .filter(pred)
                .collect(engine="streaming"),
            )
    except Exception:
        print(last_pred)
        raise


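# The row index must run continuously across all files of a multiscan, and
# offsets, slices, and filters must all observe that single sequence.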
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
@pytest.mark.write_disk
def test_multiscan_row_index(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
) -> None:
    a = pl.DataFrame({"col": [5, 10, 1996]})
    b = pl.DataFrame({"col": [42]})
    c = pl.DataFrame({"col": [13, 37]})

    write(a, tmp_path / "a")
    write(b, tmp_path / "b")
    write(c, tmp_path / "c")

    col = pl.concat([a, b, c]).to_series()
    g = tmp_path / "*"

    assert_frame_equal(
        scan(g, row_index_name="ri").collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(6), get_index_type()),
                col,
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start).collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(start, start + 6), get_index_type()),
                col,
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start).slice(3, 3).collect(),
        pl.DataFrame(
            [
                pl.Series("ri", range(start + 3, start + 6), get_index_type()),
                col.slice(3, 3),
            ]
        ),
    )

    start = 42
    assert_frame_equal(
        scan(g, row_index_name="ri", row_index_offset=start)
        .filter(pl.col("col") < 15)
        .collect(),
        pl.DataFrame(
            [
                pl.Series("ri", [start + 0, start + 1, start + 4], get_index_type()),
                pl.Series("col", [5, 10, 13]),
            ]
        ),
    )


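# A dtype conflict between files must raise rather than silently cast.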
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),
        (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),
        pytest.param(
            pl.scan_csv,
            pl.DataFrame.write_csv,
            "csv",
            marks=pytest.mark.xfail(
                reason="See https://github.com/pola-rs/polars/issues/21211"
            ),
        ),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"),
    ],
)
@pytest.mark.write_disk
def test_schema_mismatch_type_mismatch(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"xyz_col": [5, 10, 1996]})
    b = pl.DataFrame({"xyz_col": ["a", "b", "c"]})

    a_path = tmp_path / f"a.{ext}"
    b_path = tmp_path / f"b.{ext}"

    multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    q = scan(multiscan_path)

    # NDJSON will just parse according to `projected_schema`
    cx = (
        pytest.raises(pl.exceptions.ComputeError, match="cannot parse 'a' as Int64")
        if scan is pl.scan_ndjson
        else pytest.raises(
            pl.exceptions.SchemaError,  # type: ignore[arg-type]
            match=(
                "data type mismatch for column xyz_col: "
                "incoming: String != target: Int64"
            ),
        )
    )

    with cx:
        q.collect(engine="streaming")


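# Columns appearing in a different order across files must raise a SchemaError.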
@pytest.mark.parametrize(
    ("scan", "write", "ext"),
    [
        # (pl.scan_parquet, pl.DataFrame.write_parquet, "parquet"),  # TODO: _
        # (pl.scan_ipc, pl.DataFrame.write_ipc, "ipc"),  # TODO: _
        pytest.param(
            pl.scan_csv,
            pl.DataFrame.write_csv,
            "csv",
            marks=pytest.mark.xfail(
                reason="See https://github.com/pola-rs/polars/issues/21211"
            ),
        ),
        # (pl.scan_ndjson, pl.DataFrame.write_ndjson, "jsonl"),  # TODO: _
    ],
)
@pytest.mark.write_disk
def test_schema_mismatch_order_mismatch(
    tmp_path: Path,
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, Path], Any],
    ext: str,
) -> None:
    a = pl.DataFrame({"x": [5, 10, 1996], "y": ["a", "b", "c"]})
    b = pl.DataFrame({"y": ["x", "y"], "x": [1, 2]})

    a_path = tmp_path / f"a.{ext}"
    b_path = tmp_path / f"b.{ext}"

    multiscan_path = tmp_path / f"*.{ext}"

    write(a, a_path)
    write(b, b_path)

    q = scan(multiscan_path)

    with pytest.raises(pl.exceptions.SchemaError):
        q.collect(engine="streaming")


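# head() across a multiscan: the first 5 rows all come from the first source.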
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_multiscan_head(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    a = io.BytesIO()
    b = io.BytesIO()
    for f in [a, b]:
        write(pl.Series("c1", range(10)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan([a, b]).head(5).collect(engine="streaming"),
        pl.Series("c1", range(5)).to_frame(),
    )


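# tail() mirrors head(): the last 5 rows all come from the second source.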
@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
        (
            pl.scan_csv,
            pl.DataFrame.write_csv,
        ),
    ],
)
def test_multiscan_tail(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    a = io.BytesIO()
    b = io.BytesIO()
    for f in [a, b]:
        write(pl.Series("c1", range(10)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan([a, b]).tail(5).collect(engine="streaming"),
        pl.Series("c1", range(5, 10)).to_frame(),
    )


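# Slice arithmetic below: 13 sources x 7 rows = 91 rows total. An offset of
# 5 * 7 - 5 = 30 starts at row 2 of fs[4], so a slice of length 17 spans the
# tail of fs[4], all of fs[5], and the head of fs[6]. The negative offset
# -(13 * 7 - 30) = -61 addresses the same window from the end.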
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_multiscan_slice_middle(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
) -> None:
    fs = [io.BytesIO() for _ in range(13)]
    for f in fs:
        write(pl.Series("c1", range(7)).to_frame(), f)
        f.seek(0)

    offset = 5 * 7 - 5
    expected = (
        list(range(2, 7))  # fs[4]
        + list(range(7))  # fs[5]
        + list(range(5))  # fs[6]
    )
    expected_series = [pl.Series("c1", expected)]
    ri_expected_series = [
        pl.Series("ri", range(offset, offset + 17), get_index_type())
    ] + expected_series

    assert_frame_equal(
        scan(fs).slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(expected_series),
    )
    assert_frame_equal(
        scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(ri_expected_series),
    )

    # Negative slices
    offset = -(13 * 7 - offset)
    assert_frame_equal(
        scan(fs).slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(expected_series),
    )
    assert_frame_equal(
        scan(fs, row_index_name="ri").slice(offset, 17).collect(engine="streaming"),
        pl.DataFrame(ri_expected_series),
    )


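# Property test: slicing a multiscan of 13 small sources must agree with
# slicing one source holding the same 91 rows, for arbitrary offsets and
# lengths, with and without a row index.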
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
@given(offset=st.integers(-100, 100), length=st.integers(0, 101))
def test_multiscan_slice_parametric(
    scan: Callable[..., pl.LazyFrame],
    write: Callable[[pl.DataFrame, io.BytesIO | Path], Any],
    offset: int,
    length: int,
) -> None:
    ref = io.BytesIO()
    write(pl.Series("c1", [i % 7 for i in range(13 * 7)]).to_frame(), ref)
    ref.seek(0)

    fs = [io.BytesIO() for _ in range(13)]
    for f in fs:
        write(pl.Series("c1", range(7)).to_frame(), f)
        f.seek(0)

    assert_frame_equal(
        scan(ref).slice(offset, length).collect(),
        scan(fs).slice(offset, length).collect(engine="streaming"),
    )

    ref.seek(0)
    for f in fs:
        f.seek(0)

    assert_frame_equal(
        scan(ref, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .collect(),
        scan(fs, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .collect(engine="streaming"),
    )

    assert_frame_equal(
        scan(ref, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .select("ri")
        .collect(),
        scan(fs, row_index_name="ri", row_index_offset=42)
        .slice(offset, length)
        .select("ri")
        .collect(engine="streaming"),
    )


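# Scanning the same in-memory buffer 1023 times should concatenate cleanly.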
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_many_files(scan: Any, write: Any) -> None:
    f = io.BytesIO()
    write(pl.DataFrame({"a": [5, 10, 1996]}), f)
    bs = f.getvalue()

    out = scan([bs] * 1023)

    assert_frame_equal(
        out.collect(),
        pl.DataFrame(
            {
                "a": [5, 10, 1996] * 1023,
            }
        ),
    )


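# Regression test: with only two threads and POLARS_JOIN_SAMPLE_LIMIT=1, the
# streaming join can ask its multiscan inputs to stop early; this must not
# deadlock.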
def test_deadlock_stop_requested(monkeypatch: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()
    df.write_parquet(f, row_group_size=1)

    monkeypatch.setenv("POLARS_MAX_THREADS", "2")
    monkeypatch.setenv("POLARS_JOIN_SAMPLE_LIMIT", "1")

    left_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]
    right_fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]

    left = pl.scan_parquet(left_fs)  # type: ignore[arg-type]
    right = pl.scan_parquet(right_fs)  # type: ignore[arg-type]

    left.join(right, pl.col.a == pl.col.a).collect(engine="streaming").height


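# Regression test: with slice pushdown disabled, head(100) must linearize the
# morsels from all 10 sources without deadlocking.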
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_deadlock_linearize(scan: Any, write: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()
    write(df, f)
    fs = [io.BytesIO(f.getbuffer()) for _ in range(10)]
    lf = scan(fs).head(100)

    assert_frame_equal(
        lf.collect(
            engine="streaming", optimizations=pl.QueryOptFlags(slice_pushdown=False)
        ),
        pl.concat([df] * 10),
    )


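# Filters on the row index must select the same rows as the equivalent slice
# (issue 22612).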
@pytest.mark.parametrize(
    ("scan", "write"),
    SCAN_AND_WRITE_FUNCS,
)
def test_row_index_filter_22612(scan: Any, write: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        }
    )

    f = io.BytesIO()

    if write is pl.DataFrame.write_parquet:
        df.write_parquet(f, row_group_size=5)
        assert pq.read_metadata(f).num_row_groups == 2
    else:
        write(df, f)

    for end in range(2, 10):
        assert_frame_equal(
            scan(f)
            .with_row_index()
            .filter(pl.col("index") >= end - 2, pl.col("index") <= end)
            .collect(),
            df.with_row_index().slice(end - 2, 3),
        )

        assert_frame_equal(
            scan(f)
            .with_row_index()
            .filter(pl.col("index").is_between(end - 2, end))
            .collect(),
            df.with_row_index().slice(end - 2, 3),
        )


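# Adding a row index whose name collides with a column in the file must raise.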
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_row_index_name_in_file(scan: Any, write: Any) -> None:
    f = io.BytesIO()
    write(pl.DataFrame({"index": 1}), f)

    with pytest.raises(
        pl.exceptions.DuplicateError,
        match="cannot add row_index with name 'index': column already exists in file",
    ):
        scan(f).with_row_index().collect()


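# An extra column in one file must raise unless extra_columns="ignore" is
# passed (issue 22218).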
def test_extra_columns_not_ignored_22218() -> None:
    dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "c": 2})]

    files: list[IO[bytes]] = [io.BytesIO(), io.BytesIO()]

    dfs[0].write_parquet(files[0])
    dfs[1].write_parquet(files[1])

    with pytest.raises(
        pl.exceptions.SchemaError,
        match="extra column in file outside of expected schema: c, hint: specify .*or pass",
    ):
        (pl.scan_parquet(files, missing_columns="insert").select(pl.all()).collect())

    assert_frame_equal(
        pl.scan_parquet(
            files,
            missing_columns="insert",
            extra_columns="ignore",
        )
        .select(pl.all())
        .collect(),
        pl.DataFrame({"a": [1, 2], "b": [1, None]}),
    )


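# A file containing only a Null column must upcast to the schema of its
# sibling files.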
@pytest.mark.parametrize(("scan", "write"), SCAN_AND_WRITE_FUNCS)
def test_scan_null_upcast(scan: Any, write: Any) -> None:
    dfs = [
        pl.DataFrame({"a": [1, 2, 3]}),
        pl.select(a=pl.lit(None, dtype=pl.Null)),
    ]

    files = [io.BytesIO(), io.BytesIO()]

    write(dfs[0], files[0])
    write(dfs[1], files[1])

    # Prevent CSV schema inference from loading as string (it looks at multiple
    # files).
    if scan is pl.scan_csv:
        scan = partial(scan, schema=dfs[0].schema)

    assert_frame_equal(
        scan(files).collect(),
        pl.DataFrame({"a": [1, 2, 3, None]}),
    )


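# Null columns must upcast to nested types (here a list of structs) as well.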
@pytest.mark.parametrize(
    ("scan", "write"),
    [
        (pl.scan_ipc, pl.DataFrame.write_ipc),
        (pl.scan_parquet, pl.DataFrame.write_parquet),
        (pl.scan_ndjson, pl.DataFrame.write_ndjson),
    ],
)
def test_scan_null_upcast_to_nested(scan: Any, write: Any) -> None:
    schema = {"a": pl.List(pl.Struct({"field": pl.Int64}))}

    dfs = [
        pl.DataFrame(
            {"a": [[{"field": 1}], [{"field": 2}], []]},
            schema=schema,
        ),
        pl.select(a=pl.lit(None, dtype=pl.Null)),
    ]

    files = [io.BytesIO(), io.BytesIO()]

    write(dfs[0], files[0])
    write(dfs[1], files[1])

    # Prevent CSV schema inference from loading as string (it looks at multiple
    # files).
    if scan is pl.scan_csv:
        scan = partial(scan, schema=schema)

    assert_frame_equal(
        scan(files).collect(),
        pl.DataFrame(
            {"a": [[{"field": 1}], [{"field": 2}], [], None]},
            schema=schema,
        ),
    )