Path: blob/main/py-polars/tests/unit/io/test_partition.py
from __future__ import annotations

import io
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypedDict

import pytest
from hypothesis import example, given

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from polars.testing.parametric.strategies import dataframes
from tests.unit.io.conftest import format_file_uri

if TYPE_CHECKING:
    from polars._typing import EngineType
    from polars.io.partition import FileProviderArgs


class IOType(TypedDict):
    """A type of IO."""

    ext: str
    scan: Any
    sink: Any


io_types: list[IOType] = [
    {"ext": "csv", "scan": pl.scan_csv, "sink": pl.LazyFrame.sink_csv},
    {"ext": "jsonl", "scan": pl.scan_ndjson, "sink": pl.LazyFrame.sink_ndjson},
    {"ext": "parquet", "scan": pl.scan_parquet, "sink": pl.LazyFrame.sink_parquet},
    {"ext": "ipc", "scan": pl.scan_ipc, "sink": pl.LazyFrame.sink_ipc},
]

engines: list[EngineType] = [
    "streaming",
    "in-memory",
]


def test_partition_by_api() -> None:
    with pytest.raises(
        ValueError,
        match=r"at least one of \('key', 'max_rows_per_file', 'approximate_bytes_per_file'\) must be specified for PartitionBy",
    ):
        pl.PartitionBy("")

    error_cx = pytest.raises(
        ValueError, match="cannot use 'include_key' without specifying 'key'"
    )

    with error_cx:
        pl.PartitionBy("", include_key=True, max_rows_per_file=1)

    with error_cx:
        pl.PartitionBy("", include_key=False, max_rows_per_file=1)

    # With only `key` given, `approximate_bytes_per_file` defaults to
    # 2**32 - 1 bytes (~4 GiB).
    assert (
        pl.PartitionBy("", key="key")._pl_partition_by.approximate_bytes_per_file
        == 4_294_967_295
    )

    # If `max_rows_per_file` was given then `approximate_bytes_per_file` should
    # default to disabled (u64::MAX).
    assert (
        pl.PartitionBy(
            "", max_rows_per_file=1
        )._pl_partition_by.approximate_bytes_per_file
        == (1 << 64) - 1
    )

    assert (
        pl.PartitionBy(
            "", key="key", max_rows_per_file=1
        )._pl_partition_by.approximate_bytes_per_file
        == (1 << 64) - 1
    )

    assert (
        pl.PartitionBy(
            "", max_rows_per_file=1, approximate_bytes_per_file=1024
        )._pl_partition_by.approximate_bytes_per_file
        == 1024
    )


@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
@pytest.mark.parametrize("length", [0, 1, 4, 5, 6, 7])
@pytest.mark.parametrize("max_size", [1, 2, 3])
@pytest.mark.write_disk
def test_max_size_partition(
    tmp_path: Path,
    io_type: IOType,
    engine: EngineType,
    length: int,
    max_size: int,
) -> None:
    lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()

    (io_type["sink"])(
        lf,
        pl.PartitionBy(tmp_path, max_rows_per_file=max_size),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close on
        # one thread is immediately visible on another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    i = 0
    while length > 0:
        assert (io_type["scan"])(tmp_path / f"{i:08}.{io_type['ext']}").select(
            pl.len()
        ).collect()[0, 0] == min(max_size, length)

        length -= max_size
        i += 1


def test_partition_by_max_rows_per_file() -> None:
    files = {}

    def file_path_provider(args: FileProviderArgs) -> Any:
        f = io.BytesIO()
        files[args.index_in_partition] = f
        return f

    df = pl.select(x=pl.int_range(0, 100))
    df.lazy().sink_parquet(
        pl.PartitionBy("", file_path_provider=file_path_provider, max_rows_per_file=10)
    )

    for f in files.values():
        f.seek(0)

    assert_frame_equal(
        pl.scan_parquet([files[i] for i in range(len(files))]).collect(),  # type: ignore[arg-type]
        df,
    )

    for f in files.values():
        f.seek(0)

    assert [
        pl.scan_parquet(files[i]).select(pl.len()).collect().item()
        for i in range(len(files))
    ] == [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
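

# A minimal sketch, not part of the original suite, pinning down the
# `FileProviderArgs` field the provider-based tests in this file rely on:
# `index_in_partition` counts output files from zero, which is what the
# zero-padded default file names encode. The sort guards against the provider
# being called in an unspecified order.
def test_file_provider_index_sketch() -> None:
    seen: list[int] = []

    def provider(args: FileProviderArgs) -> io.BytesIO:
        seen.append(args.index_in_partition)
        return io.BytesIO()

    pl.LazyFrame({"x": [1, 2, 3, 4]}).sink_parquet(
        pl.PartitionBy("", file_path_provider=provider, max_rows_per_file=2)
    )

    # Four rows capped at two rows per file gives two files, indexed 0 and 1.
    assert sorted(seen) == [0, 1]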


@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
def test_max_size_partition_lambda(
    tmp_path: Path, io_type: IOType, engine: EngineType
) -> None:
    length = 17
    max_size = 3
    lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()

    (io_type["sink"])(
        lf,
        pl.PartitionBy(
            tmp_path,
            file_path_provider=lambda args: (
                tmp_path / f"abc-{args.index_in_partition:08}.{io_type['ext']}"
            ),
            max_rows_per_file=max_size,
        ),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close on
        # one thread is immediately visible on another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    i = 0
    while length > 0:
        assert (io_type["scan"])(tmp_path / f"abc-{i:08}.{io_type['ext']}").select(
            pl.len()
        ).collect()[0, 0] == min(max_size, length)

        length -= max_size
        i += 1


@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
@pytest.mark.write_disk
def test_partition_by_key(
    tmp_path: Path,
    io_type: IOType,
    engine: EngineType,
) -> None:
    lf = pl.Series("a", [i % 4 for i in range(7)], pl.Int64).to_frame().lazy()

    (io_type["sink"])(
        lf,
        pl.PartitionBy(
            tmp_path,
            file_path_provider=lambda args: (
                f"{args.partition_keys.item()}.{io_type['ext']}"
            ),
            key="a",
        ),
        engine=engine,
        # We need to sync here because platforms do not guarantee that a close on
        # one thread is immediately visible on another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    assert_series_equal(
        (io_type["scan"])(tmp_path / f"0.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [0, 0], pl.Int64),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"1.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [1, 1], pl.Int64),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"2.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [2, 2], pl.Int64),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"3.{io_type['ext']}").collect().to_series(),
        pl.Series("a", [3], pl.Int64),
    )

    scan_flags = (
        {"schema": pl.Schema({"a": pl.String()})} if io_type["ext"] == "csv" else {}
    )

    # Change the datatype.
    (io_type["sink"])(
        lf,
        pl.PartitionBy(
            tmp_path,
            file_path_provider=lambda args: (
                f"{args.partition_keys.item()}.{io_type['ext']}"
            ),
            key=pl.col.a.cast(pl.String()),
        ),
        engine=engine,
        sync_on_close="data",
    )

    assert_series_equal(
        (io_type["scan"])(tmp_path / f"0.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["0", "0"], pl.String),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"1.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["1", "1"], pl.String),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"2.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["2", "2"], pl.String),
    )
    assert_series_equal(
        (io_type["scan"])(tmp_path / f"3.{io_type['ext']}", **scan_flags)
        .collect()
        .to_series(),
        pl.Series("a", ["3"], pl.String),
    )


# We only deal with self-describing formats.
@pytest.mark.parametrize("io_type", [io_types[2], io_types[3]])
@example(df=pl.DataFrame({"a": [0.0, -0.0]}, schema={"a": pl.Float16}))
@given(
    df=dataframes(
        min_cols=1,
        min_size=1,
        excluded_dtypes=[
            pl.Decimal,  # Bug see: https://github.com/pola-rs/polars/issues/21684
            pl.Duration,  # Bug see: https://github.com/pola-rs/polars/issues/21964
            pl.Categorical,  # We cannot ensure the string cache is properly held.
            # Generate invalid UTF-8
            pl.Binary,
            pl.Struct,
            pl.Array,
            pl.List,
            pl.Extension,  # Can't be cast to string
        ],
    )
)
def test_partition_by_key_parametric(
    io_type: IOType,
    df: pl.DataFrame,
) -> None:
    col1 = df.columns[0]

    output_files = []

    def file_path_provider(args: FileProviderArgs) -> io.BytesIO:
        f = io.BytesIO()
        output_files.append(f)
        return f

    (io_type["sink"])(
        df.lazy(),
        pl.PartitionBy(
            "",
            file_path_provider=file_path_provider,
            key=col1,
        ),
        # We need to sync here because platforms do not guarantee that a close on
        # one thread is immediately visible on another thread.
        #
        # "Multithreaded processes and close()"
        # https://man7.org/linux/man-pages/man2/close.2.html
        sync_on_close="data",
    )

    for f in output_files:
        f.seek(0)

    assert_frame_equal(
        io_type["scan"](output_files).collect(),
        df,
        check_row_order=False,
    )


def test_partition_by_file_naming_preserves_order(tmp_path: Path) -> None:
    df = pl.DataFrame({"x": range(100)})
    df.lazy().sink_parquet(pl.PartitionBy(tmp_path, max_rows_per_file=1))

    output_files = sorted(tmp_path.iterdir())
    assert len(output_files) == 100

    assert_frame_equal(pl.scan_parquet(output_files).collect(), df)


@pytest.mark.parametrize("io_type", io_types)
@pytest.mark.parametrize("engine", engines)
def test_partition_to_memory(io_type: IOType, engine: EngineType) -> None:
    df = pl.DataFrame(
        {
            "a": [5, 10, 1996],
        }
    )

    output_files = {}

    def file_path_provider(args: FileProviderArgs) -> io.BytesIO:
        f = io.BytesIO()
        output_files[args.index_in_partition] = f
        return f

    io_type["sink"](
        df.lazy(),
        pl.PartitionBy("", file_path_provider=file_path_provider, max_rows_per_file=1),
        engine=engine,
    )

    assert len(output_files) == df.height

    for f in output_files.values():
        f.seek(0)

    assert_frame_equal(
        io_type["scan"](output_files[0]).collect(), pl.DataFrame({"a": [5]})
    )
    assert_frame_equal(
        io_type["scan"](output_files[1]).collect(), pl.DataFrame({"a": [10]})
    )
    assert_frame_equal(
        io_type["scan"](output_files[2]).collect(), pl.DataFrame({"a": [1996]})
    )


@pytest.mark.write_disk
def test_partition_key_order_22645(tmp_path: Path) -> None:
    pl.LazyFrame({"a": [1]}).sink_parquet(
        pl.PartitionBy(
            tmp_path,
            key=[pl.col.a.alias("b"), (pl.col.a + 42).alias("c")],
        ),
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path / "b=1" / "c=43").collect(),
        pl.DataFrame({"a": [1], "b": [1], "c": [43]}),
    )


@pytest.mark.write_disk
def test_parquet_preserve_order_within_partition_23376(tmp_path: Path) -> None:
    ll = list(range(20))
    df = pl.DataFrame({"a": ll})
    df.lazy().sink_parquet(pl.PartitionBy(tmp_path, max_rows_per_file=1))
    out = pl.scan_parquet(tmp_path).collect().to_series().to_list()
    assert ll == out
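

# A sketch, not part of the original suite, restating the assumption behind
# the two order-preservation tests above: the default file names zero-pad the
# file index to eight digits, so lexicographic order (what `sorted` gives) and
# numeric order agree for anything under 10**8 files.
@pytest.mark.write_disk
def test_zero_padded_names_sort_numerically_sketch(tmp_path: Path) -> None:
    pl.LazyFrame({"a": list(range(12))}).sink_parquet(
        pl.PartitionBy(tmp_path, max_rows_per_file=1)
    )

    names = sorted(p.name for p in tmp_path.iterdir())
    assert names == [f"{i:08}.parquet" for i in range(12)]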


@pytest.mark.write_disk
def test_file_path_cb_new_cloud_path(tmp_path: Path) -> None:
    i = 0

    def new_path(_: Any) -> str:
        nonlocal i
        p = format_file_uri(f"{tmp_path}/pms-{i:08}.parquet")
        i += 1
        return p

    df = pl.DataFrame({"a": [1, 2]})
    df.lazy().sink_csv(
        pl.PartitionBy(
            "s3://bucket-x", file_path_provider=new_path, max_rows_per_file=1
        )
    )

    # The URIs returned by the callback take precedence over the base path, so
    # nothing is written to `s3://bucket-x` and the output lands locally.
    assert_frame_equal(pl.scan_csv(tmp_path).collect(), df, check_row_order=False)


@pytest.mark.write_disk
def test_partition_empty_string_24545(tmp_path: Path) -> None:
    df = pl.DataFrame(
        {
            "a": ["", None, "abc", "xyz"],
            "b": [1, 2, 3, 4],
        }
    )

    df.write_parquet(tmp_path, partition_by="a")

    assert_frame_equal(pl.read_parquet(tmp_path), df)


@pytest.mark.write_disk
@pytest.mark.parametrize("dtype", [pl.Int64(), pl.Date(), pl.Datetime()])
def test_partition_empty_dtype_24545(tmp_path: Path, dtype: pl.DataType) -> None:
    df = pl.DataFrame({"b": [1, 2, 3, 4]}).with_columns(
        a=pl.col.b.cast(dtype),
    )

    df.write_parquet(tmp_path, partition_by="a")
    extra = pl.select(b=pl.lit(0, pl.Int64), a=pl.lit(None, dtype))
    extra.write_parquet(Path(tmp_path / "a=" / "000.parquet"), mkdir=True)

    assert_frame_equal(pl.read_parquet(tmp_path), pl.concat([extra, df]))


@pytest.mark.slow
@pytest.mark.write_disk
def test_partition_approximate_size(tmp_path: Path) -> None:
    n_rows = 500_000
    df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows))

    root = tmp_path
    df.lazy().sink_parquet(
        pl.PartitionBy(root, approximate_bytes_per_file=200_000),
        row_group_size=10_000,
    )

    files = sorted(root.iterdir())

    assert len(files) == 30

    assert [
        pl.scan_parquet(x).select(pl.len()).collect().item() for x in files
    ] == 29 * [16667] + [16657]

    assert_frame_equal(pl.scan_parquet(root).collect(), df)


def test_sink_partitioned_forbid_non_elementwise_key_expr_25535() -> None:
    with pytest.raises(
        InvalidOperationError,
        match="cannot use non-elementwise expressions for PartitionBy keys",
    ):
        pl.LazyFrame({"a": 1}).sink_parquet(pl.PartitionBy("", key=pl.col("a").sum()))


@pytest.mark.write_disk
@pytest.mark.parametrize(
    ("scan_func", "sink_func"),
    [
        (pl.scan_parquet, pl.LazyFrame.sink_parquet),
        (pl.scan_ipc, pl.LazyFrame.sink_ipc),
    ],
)
def test_sink_partitioned_no_columns_in_file_25535(
    tmp_path: Path, scan_func: Any, sink_func: Any
) -> None:
    df = pl.DataFrame({"x": [1, 1, 1, 1, 1]})
    partitioned_root = tmp_path / "partitioned"
    sink_func(
        df.lazy(),
        pl.PartitionBy(partitioned_root, key="x", include_key=False),
    )

    assert_frame_equal(scan_func(partitioned_root).collect(), df)

    max_size_root = tmp_path / "max-size"
    sink_func(
        pl.LazyFrame(height=10),
        pl.PartitionBy(max_size_root, max_rows_per_file=2),
    )

    assert sum(1 for _ in max_size_root.iterdir()) == 5
    assert scan_func(max_size_root).collect().shape == (10, 0)
    assert scan_func(max_size_root).select(pl.len()).collect().item() == 10


def test_partition_by_scalar_expr_26294(tmp_path: Path) -> None:
    pl.LazyFrame(height=5).sink_parquet(
        pl.PartitionBy(tmp_path, key=pl.lit(1, dtype=pl.Int64))
    )

    assert_frame_equal(
        pl.scan_parquet(tmp_path).collect(),
        pl.DataFrame({"literal": [1, 1, 1, 1, 1]}),
    )


def test_partition_by_diff_expr_26370(tmp_path: Path) -> None:
    q = pl.LazyFrame({"x": [1, 2]}).cast(pl.Decimal(precision=1))
    q = q.with_columns(pl.col("x").diff().alias("y"), pl.lit(1).alias("z"))

    q.sink_parquet(pl.PartitionBy(tmp_path, key="z"))

    assert_frame_equal(pl.scan_parquet(tmp_path).collect(), q.collect())
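

# A sketch, not part of the original suite, of an assumption the key-based
# tests above lean on: `partition_keys` carries the single row of key values
# for the partition being opened, so `.item()` extracts the scalar key when
# there is exactly one key column. The sort guards against partitions being
# opened in an unspecified order.
def test_partition_keys_item_sketch() -> None:
    keys: list[int] = []

    def provider(args: FileProviderArgs) -> io.BytesIO:
        keys.append(args.partition_keys.item())
        return io.BytesIO()

    pl.LazyFrame({"a": [0, 1, 0, 2]}).sink_parquet(
        pl.PartitionBy("", file_path_provider=provider, key="a")
    )

    # One file per distinct key value.
    assert sorted(keys) == [0, 1, 2]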