GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_partition.py
from __future__ import annotations

import io
from typing import TYPE_CHECKING, Any, TypedDict

import pytest
from hypothesis import given

import polars as pl
from polars.io.partition import (
    PartitionByKey,
    PartitionMaxSize,
    PartitionParted,
)
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric.strategies import dataframes

if TYPE_CHECKING:
    from pathlib import Path

    from polars._typing import EngineType
    from polars.io.partition import BasePartitionContext, KeyedPartitionContext


class IOType(TypedDict):
    """A file format with its matching scan and sink functions."""

    ext: str
    scan: Any
    sink: Any


io_types: list[IOType] = [
    {"ext": "csv", "scan": pl.scan_csv, "sink": pl.LazyFrame.sink_csv},
    {"ext": "jsonl", "scan": pl.scan_ndjson, "sink": pl.LazyFrame.sink_ndjson},
    {"ext": "parquet", "scan": pl.scan_parquet, "sink": pl.LazyFrame.sink_parquet},
    {"ext": "ipc", "scan": pl.scan_ipc, "sink": pl.LazyFrame.sink_ipc},
]

engines: list[EngineType] = [
    "streaming",
    "in-memory",
]
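

# PartitionMaxSize splits the sink output into files of at most `max_size`
# rows; by default the files are named with an incrementing zero-padded hex
# index (e.g. "00000000.csv", "00000001.csv", ...).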
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
@pytest.mark.parametrize("length", [0, 1, 4, 5, 6, 7])
@pytest.mark.parametrize("max_size", [1, 2, 3])
@pytest.mark.write_disk
def test_max_size_partition(
    tmp_path: Path,
    io_type: IOType,
    engine: EngineType,
    length: int,
    max_size: int,
) -> None:
    lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()

    (io_type["sink"])(
        lf,
        PartitionMaxSize(tmp_path, max_size=max_size),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close
        # on one thread is immediately visible to another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    i = 0
    while length > 0:
        assert (io_type["scan"])(tmp_path / f"{i:08x}.{io_type['ext']}").select(
            pl.len()
        ).collect()[0, 0] == min(max_size, length)

        length -= max_size
        i += 1
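

# The `file_path` callback can rewrite the default file name; here it
# prefixes every partition file with "abc-".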
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
def test_max_size_partition_lambda(
    tmp_path: Path, io_type: IOType, engine: EngineType
) -> None:
    length = 17
    max_size = 3
    lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()

    (io_type["sink"])(
        lf,
        PartitionMaxSize(
            tmp_path,
            file_path=lambda ctx: ctx.file_path.with_name("abc-" + ctx.file_path.name),
            max_size=max_size,
        ),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close
        # on one thread is immediately visible to another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    i = 0
    while length > 0:
        assert (io_type["scan"])(tmp_path / f"abc-{i:08x}.{io_type['ext']}").select(
            pl.len()
        ).collect()[0, 0] == min(max_size, length)

        length -= max_size
        i += 1
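

# PartitionByKey writes one file per distinct value of the `by` expression;
# the key expression may also change the schema of the written data (e.g.
# the cast to pl.String below).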
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
@pytest.mark.write_disk
def test_partition_by_key(
    tmp_path: Path,
    io_type: IOType,
    engine: EngineType,
) -> None:
    lf = pl.Series("a", [i % 4 for i in range(7)], pl.Int64).to_frame().lazy()

    (io_type["sink"])(
        lf,
        PartitionByKey(
            tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by="a"
        ),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close
        # on one thread is immediately visible to another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    assert_series_equal(
        (io_type["scan"])(tmp_path / f"0.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [0, 0], pl.Int64),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"1.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [1, 1], pl.Int64),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"2.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [2, 2], pl.Int64),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"3.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [3], pl.Int64),
    )

    scan_flags = (
        {"schema": pl.Schema({"a": pl.String()})} if io_type["ext"] == "csv" else {}
    )

    # Change the datatype of the key.
    (io_type["sink"])(
        lf,
        PartitionByKey(
            tmp_path,
            file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}",
            by=pl.col.a.cast(pl.String()),
        ),
        engine=engine,
        sync_on_close="data",
    )

    assert_series_equal(
        (io_type["scan"])(tmp_path / f"0.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["0", "0"], pl.String),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"1.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["1", "1"], pl.String),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"2.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["2", "2"], pl.String),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"3.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["3"], pl.String),
    )
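

# PartitionParted starts a new file at every change of the key value, i.e.
# one file per consecutive run of equal keys; the test mirrors this with
# `Series.rle`.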
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
@pytest.mark.write_disk
def test_partition_parted(tmp_path: Path, io_type: IOType, engine: EngineType) -> None:
    s = pl.Series("a", [1, 1, 2, 3, 3, 4, 4, 4, 6], pl.Int64)
    lf = s.to_frame().lazy()

    (io_type["sink"])(
        lf,
        PartitionParted(
            tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by="a"
        ),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close
        # on one thread is immediately visible to another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    rle = s.rle()

    for i, row in enumerate(rle.struct.unnest().rows(named=True)):
        assert_series_equal(
            (io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").collect().to_series(),
            pl.Series("a", [row["value"]] * row["len"], pl.Int64),
        )

    scan_flags = (
        {"schema_overrides": pl.Schema({"a_str": pl.String()})}
        if io_type["ext"] == "csv"
        else {}
    )

    # Add a String cast of the key as a second key column.
    (io_type["sink"])(
        lf,
        PartitionParted(
            tmp_path,
            file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}",
            by=[pl.col.a, pl.col.a.cast(pl.String()).alias("a_str")],
        ),
        engine=engine,
        sync_on_close="data",
    )

    for i, row in enumerate(rle.struct.unnest().rows(named=True)):
        assert_frame_equal(
            (io_type["scan"])(
                tmp_path / f"{i}.{io_type['ext']}", **scan_flags
            ).collect(),
            pl.DataFrame(
                [
                    pl.Series("a", [row["value"]] * row["len"], pl.Int64),
                    pl.Series("a_str", [str(row["value"])] * row["len"], pl.String),
                ]
            ),
        )

    # Do not include the key columns in the output.
    (io_type["sink"])(
        lf,
        PartitionParted(
            tmp_path,
            file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}",
            by=[pl.col.a.cast(pl.String()).alias("a_str")],
            include_key=False,
        ),
        engine=engine,
        sync_on_close="data",
    )

    for i, row in enumerate(rle.struct.unnest().rows(named=True)):
        assert_series_equal(
            (io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").collect().to_series(),
            pl.Series("a", [row["value"]] * row["len"], pl.Int64),
        )
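

# Property-based check: PartitionByKey should split a frame exactly like
# DataFrame.partition_by on the same key column.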
# We only deal with self-describing formats here (Parquet and IPC), since
# these can be scanned back without providing an explicit schema.
@pytest.mark.parametrize("io_type", [io_types[2], io_types[3]])
@pytest.mark.parametrize("engine", engines)
@pytest.mark.write_disk
@given(
    df=dataframes(
        min_cols=1,
        excluded_dtypes=[
            pl.Decimal,  # Bug, see: https://github.com/pola-rs/polars/issues/21684
            pl.Duration,  # Bug, see: https://github.com/pola-rs/polars/issues/21964
            pl.Categorical,  # We cannot ensure the string cache is properly held.
            # These may generate invalid UTF-8:
            pl.Binary,
            pl.Struct,
            pl.Array,
            pl.List,
        ],
    )
)
def test_partition_by_key_parametric(
    tmp_path_factory: pytest.TempPathFactory,
    io_type: IOType,
    engine: EngineType,
    df: pl.DataFrame,
) -> None:
    col1 = df.columns[0]

    tmp_path = tmp_path_factory.mktemp("data")

    dfs = df.partition_by(col1)
    (io_type["sink"])(
        df.lazy(),
        PartitionByKey(
            tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by=col1
        ),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close
        # on one thread is immediately visible to another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    for i, df in enumerate(dfs):
        assert_frame_equal(
            df,
            (io_type["scan"])(
                tmp_path / f"{i}.{io_type['ext']}",
            ).collect(),
        )
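

# The context passed to `file_path` exposes both the path relative to the
# base directory (`ctx.file_path`) and the resolved output path
# (`ctx.full_path`), so callers can record where partitions were written.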
def test_max_size_partition_collect_files(tmp_path: Path) -> None:
    length = 17
    max_size = 3
    lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()

    io_type = io_types[0]
    output_files = []

    def file_path_cb(ctx: BasePartitionContext) -> Path:
        print(ctx)
        print(ctx.full_path)
        output_files.append(ctx.full_path)
        print(ctx.file_path)
        return ctx.file_path

    (io_type["sink"])(
        lf,
        PartitionMaxSize(tmp_path, file_path=file_path_cb, max_size=max_size),
        engine="streaming",
        # We need to sync here because platforms do not guarantee that a close
        # on one thread is immediately visible to another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    # 17 rows at up to 3 rows per file -> 6 files.
    assert output_files == [tmp_path / f"{i:08x}.{io_type['ext']}" for i in range(6)]
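

# `file_path` may also return an open, writable object (an io.BytesIO here)
# instead of a path, in which case the partition is written to it directly
# and nothing touches the disk.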
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
def test_partition_to_memory(io_type: IOType, engine: EngineType) -> None:
    df = pl.DataFrame(
        {
            "a": [5, 10, 1996],
        }
    )

    output_files = {}

    def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO:
        f = io.BytesIO()
        output_files[ctx.file_path] = f
        return f

    io_type["sink"](
        df.lazy(),
        PartitionMaxSize("", file_path=file_path_cb, max_size=1),
        engine=engine,
    )

    assert len(output_files) == df.height
    for i, (_, value) in enumerate(output_files.items()):
        value.seek(0)
        assert_frame_equal(io_type["scan"](value).collect(), df.slice(i, 1))
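

# Regression test (issue 22645): with multiple keys, the hive-style
# directory components of the default file path ("b=1/c=43") follow the
# order in which the key expressions were given.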
def test_partition_key_order_22645() -> None:
    paths = []

    def cb(ctx: KeyedPartitionContext) -> io.BytesIO:
        paths.append(ctx.file_path.parent)
        return io.BytesIO()  # Return a dummy output.

    pl.LazyFrame({"a": [1, 2, 3]}).sink_parquet(
        pl.io.PartitionByKey(
            "",
            file_path=cb,
            by=[pl.col.a.alias("b"), (pl.col.a + 42).alias("c")],
        ),
    )

    paths.sort()
    assert [p.parts for p in paths] == [
        ("b=1", "c=43"),
        ("b=2", "c=44"),
        ("b=3", "c=45"),
    ]
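

# `per_partition_sort_by` sorts the rows within each partition file; the
# partitioning itself still follows the original row order.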
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
@pytest.mark.parametrize(
    ("df", "sorts"),
    [
        (pl.DataFrame({"a": [2, 1, 0, 4, 3, 5, 7, 8, 9]}), "a"),
        (
            pl.DataFrame(
                {"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}
            ),
            "a",
        ),
        (
            pl.DataFrame(
                {"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}
            ),
            ["a", "b"],
        ),
        (
            pl.DataFrame(
                {"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}
            ),
            "b",
        ),
        (
            pl.DataFrame(
                {"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}
            ),
            pl.col.a - pl.col.b.str.slice(1).cast(pl.Int64),
        ),
    ],
)
def test_partition_to_memory_sort_by(
    io_type: IOType,
    engine: EngineType,
    df: pl.DataFrame,
    sorts: str | pl.Expr | list[str | pl.Expr],
) -> None:
    output_files = {}

    def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO:
        f = io.BytesIO()
        output_files[ctx.file_path] = f
        return f

    io_type["sink"](
        df.lazy(),
        PartitionMaxSize(
            "", file_path=file_path_cb, max_size=3, per_partition_sort_by=sorts
        ),
        engine=engine,
    )

    assert len(output_files) == df.height / 3
    for i, (_, value) in enumerate(output_files.items()):
        value.seek(0)
        assert_frame_equal(
            io_type["scan"](value).collect(), df.slice(i * 3, 3).sort(sorts)
        )
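

# `finish_callback` runs exactly once after the sink completes and is not
# called when the sink fails; for Parquet it receives a DataFrame whose
# height matches the number of files written (three here).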
@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
def test_partition_to_memory_finish_callback(
    io_type: IOType, engine: EngineType
) -> None:
    df = pl.DataFrame(
        {
            "a": [5, 10, 1996],
        }
    )

    output_files = {}

    def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO:
        f = io.BytesIO()
        output_files[ctx.file_path] = f
        return f

    num_calls = 0

    def finish_callback(df: pl.DataFrame) -> None:
        nonlocal num_calls
        num_calls += 1

        if io_type["ext"] == "parquet":
            assert df.height == 3

    io_type["sink"](
        df.lazy(),
        PartitionMaxSize(
            "", file_path=file_path_cb, max_size=1, finish_callback=finish_callback
        ),
        engine=engine,
    )
    assert num_calls == 1

    with pytest.raises(FileNotFoundError):
        io_type["sink"](
            df.lazy(),
            PartitionMaxSize(
                "/path/to/non-existent-paths",
                max_size=1,
                finish_callback=finish_callback,
            ),
        )
    assert num_calls == 1  # Should not be called for a failed sink.
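

# Regression test (issue 23306): `finish_callback` on a frame that contains
# a nested (List) column.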
def test_finish_callback_nested_23306() -> None:
    data = [{"a": "foo", "b": "bar", "c": ["hello", "ciao", "hola", "bonjour"]}]

    lf = pl.LazyFrame(data)

    def finish_callback(df: None | pl.DataFrame = None) -> None:
        assert df is not None
        assert df.height == 1

    partitioning = pl.PartitionByKey(
        "/",
        file_path=lambda _: io.BytesIO(),
        by=["a", "b"],
        finish_callback=finish_callback,
    )

    lf.sink_parquet(partitioning, mkdir=True)
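

# Regression test (issue 23376): rows must keep their original order when
# split across max-size partition files.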
@pytest.mark.write_disk
def test_parquet_preserve_order_within_partition_23376(tmp_path: Path) -> None:
    ll = list(range(20))
    df = pl.DataFrame({"a": ll})
    df.lazy().sink_parquet(pl.PartitionMaxSize(tmp_path, max_size=1))
    out = pl.scan_parquet(tmp_path).collect().to_series().to_list()
    assert ll == out