GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_partition.py

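"""Tests for partitioned sinks configured via `pl.PartitionBy`."""
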
from __future__ import annotations

import io
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict

import pytest
from hypothesis import example, given

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric.strategies import dataframes
from tests.unit.io.conftest import format_file_uri

if TYPE_CHECKING:
    from polars._typing import EngineType
    from polars.io.partition import FileProviderArgs


class IOType(TypedDict):
    """A type of IO."""

    ext: str
    scan: Any
    sink: Any


io_types: list[IOType] = [
    {"ext": "csv", "scan": pl.scan_csv, "sink": pl.LazyFrame.sink_csv},
    {"ext": "jsonl", "scan": pl.scan_ndjson, "sink": pl.LazyFrame.sink_ndjson},
    {"ext": "parquet", "scan": pl.scan_parquet, "sink": pl.LazyFrame.sink_parquet},
    {"ext": "ipc", "scan": pl.scan_ipc, "sink": pl.LazyFrame.sink_ipc},
]

engines: list[EngineType] = [
    "streaming",
    "in-memory",
]


def test_partition_by_api() -> None:
    with pytest.raises(
        ValueError,
        match=r"at least one of \('key', 'max_rows_per_file', 'approximate_bytes_per_file'\) must be specified for PartitionBy",
    ):
        pl.PartitionBy("")

    error_cx = pytest.raises(
        ValueError, match="cannot use 'include_key' without specifying 'key'"
    )

    with error_cx:
        pl.PartitionBy("", include_key=True, max_rows_per_file=1)

    with error_cx:
        pl.PartitionBy("", include_key=False, max_rows_per_file=1)

    # With only `key` given, `approximate_bytes_per_file` defaults to u32::MAX.
    assert (
        pl.PartitionBy("", key="key")._pl_partition_by.approximate_bytes_per_file
        == 4_294_967_295
    )

    # If `max_rows_per_file` is given, `approximate_bytes_per_file` should
    # default to disabled (u64::MAX).
    assert (
        pl.PartitionBy(
            "", max_rows_per_file=1
        )._pl_partition_by.approximate_bytes_per_file
        == (1 << 64) - 1
    )

    assert (
        pl.PartitionBy(
            "", key="key", max_rows_per_file=1
        )._pl_partition_by.approximate_bytes_per_file
        == (1 << 64) - 1
    )

    assert (
        pl.PartitionBy(
            "", max_rows_per_file=1, approximate_bytes_per_file=1024
        )._pl_partition_by.approximate_bytes_per_file
        == 1024
    )


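# `max_rows_per_file` alone should split the sink output into sequentially
# numbered files ("00000000.<ext>", "00000001.<ext>", ...), each holding at
# most `max_size` rows.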
@pytest.mark.parametrize("io_type", io_types)
89
@pytest.mark.parametrize("engine", engines)
90
@pytest.mark.parametrize("length", [0, 1, 4, 5, 6, 7])
91
@pytest.mark.parametrize("max_size", [1, 2, 3])
92
@pytest.mark.write_disk
93
def test_max_size_partition(
94
tmp_path: Path,
95
io_type: IOType,
96
engine: EngineType,
97
length: int,
98
max_size: int,
99
) -> None:
100
lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()
101
102
(io_type["sink"])(
103
lf,
104
pl.PartitionBy(tmp_path, max_rows_per_file=max_size),
105
engine=engine,
106
# We need to sync here because platforms do not guarantee that a close on
107
# one thread is immediately visible on another thread.
108
#
109
# "Multithreaded processes and close()"
110
# https://man7.org/linux/man-pages/man2/close.2.html
111
sync_on_close="data",
112
)
113
114
i = 0
115
while length > 0:
116
assert (io_type["scan"])(tmp_path / f"{i:08}.{io_type['ext']}").select(
117
pl.len()
118
).collect()[0, 0] == min(max_size, length)
119
120
length -= max_size
121
i += 1
122
123
124
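# `file_path_provider` may return open writable objects (here `io.BytesIO`)
# rather than paths; 100 rows with `max_rows_per_file=10` should produce ten
# 10-row partitions.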
def test_partition_by_max_rows_per_file() -> None:
    files: dict[int, io.BytesIO] = {}

    def file_path_provider(args: FileProviderArgs) -> Any:
        f = io.BytesIO()
        files[args.index_in_partition] = f
        return f

    df = pl.select(x=pl.int_range(0, 100))
    df.lazy().sink_parquet(
        pl.PartitionBy("", file_path_provider=file_path_provider, max_rows_per_file=10)
    )

    for f in files.values():
        f.seek(0)

    assert_frame_equal(
        pl.scan_parquet([files[i] for i in range(len(files))]).collect(),  # type: ignore[arg-type]
        df,
    )

    for f in files.values():
        f.seek(0)

    assert [
        pl.scan_parquet(files[i]).select(pl.len()).collect().item()
        for i in range(len(files))
    ] == [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]


@pytest.mark.parametrize("io_type", io_types)
155
@pytest.mark.parametrize("engine", engines)
156
def test_max_size_partition_lambda(
157
tmp_path: Path, io_type: IOType, engine: EngineType
158
) -> None:
159
length = 17
160
max_size = 3
161
lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()
162
163
(io_type["sink"])(
164
lf,
165
pl.PartitionBy(
166
tmp_path,
167
file_path_provider=lambda args: (
168
tmp_path / f"abc-{args.index_in_partition:08}.{io_type['ext']}"
169
),
170
max_rows_per_file=max_size,
171
),
172
engine=engine,
173
# We need to sync here because platforms do not guarantee that a close on
174
# one thread is immediately visible on another thread.
175
#
176
# "Multithreaded processes and close()"
177
# https://man7.org/linux/man-pages/man2/close.2.html
178
sync_on_close="data",
179
)
180
181
i = 0
182
while length > 0:
183
assert (io_type["scan"])(tmp_path / f"abc-{i:08}.{io_type['ext']}").select(
184
pl.len()
185
).collect()[0, 0] == min(max_size, length)
186
187
length -= max_size
188
i += 1
189
190
191
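# Partitioning on a key should route every distinct value of "a" into its own
# file, named here via the `file_path_provider` callback.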
@pytest.mark.parametrize("io_type", io_types)
192
@pytest.mark.parametrize("engine", engines)
193
@pytest.mark.write_disk
194
def test_partition_by_key(
195
tmp_path: Path,
196
io_type: IOType,
197
engine: EngineType,
198
) -> None:
199
lf = pl.Series("a", [i % 4 for i in range(7)], pl.Int64).to_frame().lazy()
200
201
(io_type["sink"])(
202
lf,
203
pl.PartitionBy(
204
tmp_path,
205
file_path_provider=lambda args: (
206
f"{args.partition_keys.item()}.{io_type['ext']}"
207
),
208
key="a",
209
),
210
engine=engine,
211
# We need to sync here because platforms do not guarantee that a close on
212
# one thread is immediately visible on another thread.
213
#
214
# "Multithreaded processes and close()"
215
# https://man7.org/linux/man-pages/man2/close.2.html
216
sync_on_close="data",
217
)
218
219
assert_series_equal(
220
(io_type["scan"])(tmp_path / f"0.{io_type['ext']}").collect().to_series(),
221
pl.Series("a", [0, 0], pl.Int64),
222
)
223
assert_series_equal(
224
(io_type["scan"])(tmp_path / f"1.{io_type['ext']}").collect().to_series(),
225
pl.Series("a", [1, 1], pl.Int64),
226
)
227
assert_series_equal(
228
(io_type["scan"])(tmp_path / f"2.{io_type['ext']}").collect().to_series(),
229
pl.Series("a", [2, 2], pl.Int64),
230
)
231
assert_series_equal(
232
(io_type["scan"])(tmp_path / f"3.{io_type['ext']}").collect().to_series(),
233
pl.Series("a", [3], pl.Int64),
234
)
235
236
scan_flags = (
237
{"schema": pl.Schema({"a": pl.String()})} if io_type["ext"] == "csv" else {}
238
)
239
240
# Change the datatype.
241
(io_type["sink"])(
242
lf,
243
pl.PartitionBy(
244
tmp_path,
245
file_path_provider=lambda args: (
246
f"{args.partition_keys.item()}.{io_type['ext']}"
247
),
248
key=pl.col.a.cast(pl.String()),
249
),
250
engine=engine,
251
sync_on_close="data",
252
)
253
254
assert_series_equal(
255
(io_type["scan"])(tmp_path / f"0.{io_type['ext']}", **scan_flags)
256
.collect()
257
.to_series(),
258
pl.Series("a", ["0", "0"], pl.String),
259
)
260
assert_series_equal(
261
(io_type["scan"])(tmp_path / f"1.{io_type['ext']}", **scan_flags)
262
.collect()
263
.to_series(),
264
pl.Series("a", ["1", "1"], pl.String),
265
)
266
assert_series_equal(
267
(io_type["scan"])(tmp_path / f"2.{io_type['ext']}", **scan_flags)
268
.collect()
269
.to_series(),
270
pl.Series("a", ["2", "2"], pl.String),
271
)
272
assert_series_equal(
273
(io_type["scan"])(tmp_path / f"3.{io_type['ext']}", **scan_flags)
274
.collect()
275
.to_series(),
276
pl.Series("a", ["3"], pl.String),
277
)
278
279
280
# We only deal with self-describing formats
@pytest.mark.parametrize("io_type", [io_types[2], io_types[3]])
@example(df=pl.DataFrame({"a": [0.0, -0.0]}, schema={"a": pl.Float32}))
@given(
    df=dataframes(
        min_cols=1,
        min_size=1,
        excluded_dtypes=[
            pl.Decimal,  # Bug see: https://github.com/pola-rs/polars/issues/21684
            pl.Duration,  # Bug see: https://github.com/pola-rs/polars/issues/21964
            pl.Categorical,  # We cannot ensure the string cache is properly held.
            # These can generate invalid UTF-8
            pl.Binary,
            pl.Struct,
            pl.Array,
            pl.List,
            pl.Extension,  # Can't be cast to string
        ],
    )
)
def test_partition_by_key_parametric(
    io_type: IOType,
    df: pl.DataFrame,
) -> None:
    col1 = df.columns[0]

    output_files: list[io.BytesIO] = []

    def file_path_provider(args: FileProviderArgs) -> io.BytesIO:
        f = io.BytesIO()
        output_files.append(f)
        return f

    (io_type["sink"])(
        df.lazy(),
        pl.PartitionBy(
            "",
            file_path_provider=file_path_provider,
            key=col1,
        ),
        # We need to sync here because platforms do not guarantee that a close on
        # one thread is immediately visible on another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    for f in output_files:
        f.seek(0)

    assert_frame_equal(
        io_type["scan"](output_files).collect(),
        df,
        check_row_order=False,
    )


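# The default file names sort lexicographically in write order, so scanning the
# sorted directory listing should reproduce the original row order.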
def test_partition_by_file_naming_preserves_order(tmp_path: Path) -> None:
    df = pl.DataFrame({"x": range(100)})
    df.lazy().sink_parquet(pl.PartitionBy(tmp_path, max_rows_per_file=1))

    output_files = sorted(tmp_path.iterdir())
    assert len(output_files) == 100

    assert_frame_equal(pl.scan_parquet(output_files).collect(), df)


@pytest.mark.parametrize(("io_type"), io_types)
349
@pytest.mark.parametrize("engine", engines)
350
def test_partition_to_memory(io_type: IOType, engine: EngineType) -> None:
351
df = pl.DataFrame(
352
{
353
"a": [5, 10, 1996],
354
}
355
)
356
357
output_files = {}
358
359
def file_path_provider(args: FileProviderArgs) -> io.BytesIO:
360
f = io.BytesIO()
361
output_files[args.index_in_partition] = f
362
return f
363
364
io_type["sink"](
365
df.lazy(),
366
pl.PartitionBy("", file_path_provider=file_path_provider, max_rows_per_file=1),
367
engine=engine,
368
)
369
370
assert len(output_files) == df.height
371
372
for f in output_files.values():
373
f.seek(0)
374
375
assert_frame_equal(
376
io_type["scan"](output_files[0]).collect(), pl.DataFrame({"a": [5]})
377
)
378
assert_frame_equal(
379
io_type["scan"](output_files[1]).collect(), pl.DataFrame({"a": [10]})
380
)
381
assert_frame_equal(
382
io_type["scan"](output_files[2]).collect(), pl.DataFrame({"a": [1996]})
383
)
384
385
386
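# Multiple key expressions should map to nested hive-style directories in the
# order the keys were given ("b=1/c=43" here).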
@pytest.mark.write_disk
def test_partition_key_order_22645(tmp_path: Path) -> None:
    pl.LazyFrame({"a": [1]}).sink_parquet(
        pl.PartitionBy(
            tmp_path,
            key=[pl.col.a.alias("b"), (pl.col.a + 42).alias("c")],
        ),
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path / "b=1" / "c=43").collect(),
        pl.DataFrame({"a": [1], "b": [1], "c": [43]}),
    )


@pytest.mark.write_disk
def test_parquet_preserve_order_within_partition_23376(tmp_path: Path) -> None:
    ll = list(range(20))
    df = pl.DataFrame({"a": ll})
    df.lazy().sink_parquet(pl.PartitionBy(tmp_path, max_rows_per_file=1))
    out = pl.scan_parquet(tmp_path).collect().to_series().to_list()
    assert ll == out


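# The cloud base path ("s3://bucket-x") is expected to be fully overridden by
# the paths returned from `file_path_provider`, so nothing is written to object
# storage here.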
@pytest.mark.write_disk
def test_file_path_cb_new_cloud_path(tmp_path: Path) -> None:
    i = 0

    def new_path(_: Any) -> str:
        nonlocal i
        p = format_file_uri(f"{tmp_path}/pms-{i:08}.parquet")
        i += 1
        return p

    df = pl.DataFrame({"a": [1, 2]})
    df.lazy().sink_csv(
        pl.PartitionBy(
            "s3://bucket-x", file_path_provider=new_path, max_rows_per_file=1
        )
    )

    assert_frame_equal(pl.scan_csv(tmp_path).collect(), df, check_row_order=False)


@pytest.mark.write_disk
def test_partition_empty_string_24545(tmp_path: Path) -> None:
    df = pl.DataFrame(
        {
            "a": ["", None, "abc", "xyz"],
            "b": [1, 2, 3, 4],
        }
    )

    df.write_parquet(tmp_path, partition_by="a")

    assert_frame_equal(pl.read_parquet(tmp_path), df)


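# The extra file is written under a directory named "a=", which here stands for
# a null partition key; reading the dataset back should restore the null with
# the partitioned dtype.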
@pytest.mark.write_disk
@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Date(), pl.Datetime()])
def test_partition_empty_dtype_24545(tmp_path: Path, dtype: pl.DataType) -> None:
    df = pl.DataFrame({"b": [1, 2, 3, 4]}).with_columns(
        a=pl.col.b.cast(dtype),
    )

    df.write_parquet(tmp_path, partition_by="a")
    extra = pl.select(b=pl.lit(0, pl.Int64), a=pl.lit(None, dtype))
    extra.write_parquet(Path(tmp_path / "a=" / "000.parquet"), mkdir=True)

    assert_frame_equal(pl.read_parquet(tmp_path), pl.concat([extra, df]))


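# `approximate_bytes_per_file` is, as the name says, approximate: the per-file
# row counts asserted below reflect the size estimate for a ~200 kB budget
# rather than an exact byte split.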
@pytest.mark.slow
@pytest.mark.write_disk
def test_partition_approximate_size(tmp_path: Path) -> None:
    n_rows = 500_000
    df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows))

    root = tmp_path
    df.lazy().sink_parquet(
        pl.PartitionBy(root, approximate_bytes_per_file=200000),
        row_group_size=10_000,
    )

    files = sorted(root.iterdir())

    assert len(files) == 30

    assert [
        pl.scan_parquet(x).select(pl.len()).collect().item() for x in files
    ] == 29 * [16667] + [16657]

    assert_frame_equal(pl.scan_parquet(root).collect(), df)


def test_sink_partitioned_forbid_non_elementwise_key_expr_25535() -> None:
    with pytest.raises(
        InvalidOperationError,
        match="cannot use non-elementwise expressions for PartitionBy keys",
    ):
        pl.LazyFrame({"a": 1}).sink_parquet(pl.PartitionBy("", key=pl.col("a").sum()))


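# Partition files may end up with zero columns (a single key column dropped via
# `include_key=False`, or a frame with no columns at all); row counts should
# still round-trip through the scan.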
@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("scan_func", "sink_func"),
    [
        (pl.scan_parquet, pl.LazyFrame.sink_parquet),
        (pl.scan_ipc, pl.LazyFrame.sink_ipc),
    ],
)
def test_sink_partitioned_no_columns_in_file_25535(
    tmp_path: Path, scan_func: Any, sink_func: Any
) -> None:
    df = pl.DataFrame({"x": [1, 1, 1, 1, 1]})
    partitioned_root = tmp_path / "partitioned"
    sink_func(
        df.lazy(),
        pl.PartitionBy(partitioned_root, key="x", include_key=False),
    )

    assert_frame_equal(scan_func(partitioned_root).collect(), df)

    max_size_root = tmp_path / "max-size"
    sink_func(
        pl.LazyFrame(height=10),
        pl.PartitionBy(max_size_root, max_rows_per_file=2),
    )

    assert sum(1 for _ in max_size_root.iterdir()) == 5
    assert scan_func(max_size_root).collect().shape == (10, 0)
    assert scan_func(max_size_root).select(pl.len()).collect().item() == 10


def test_partition_by_scalar_expr_26294(tmp_path: Path) -> None:
    pl.LazyFrame(height=5).sink_parquet(
        pl.PartitionBy(tmp_path, key=pl.lit(1, dtype=pl.Int64))
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path).collect(),
        pl.DataFrame({"literal": [1, 1, 1, 1, 1]}),
    )


def test_partition_by_diff_expr_26370(tmp_path: Path) -> None:
    q = pl.LazyFrame({"x": [1, 2]}).cast(pl.Decimal(precision=1))
    q = q.with_columns(pl.col("x").diff().alias("y"), pl.lit(1).alias("z"))

    q.sink_parquet(pl.PartitionBy(tmp_path, key="z"))

    assert_frame_equal(pl.scan_parquet(tmp_path).collect(), q.collect())