Path: blob/main/py-polars/tests/unit/io/test_partition.py
6939 views
from __future__ import annotations12import io3from typing import TYPE_CHECKING, Any, TypedDict45import pytest6from hypothesis import given78import polars as pl9from polars.io.partition import (10PartitionByKey,11PartitionMaxSize,12PartitionParted,13)14from polars.testing import assert_frame_equal, assert_series_equal15from polars.testing.parametric.strategies import dataframes1617if TYPE_CHECKING:18from pathlib import Path1920from polars._typing import EngineType21from polars.io.partition import BasePartitionContext, KeyedPartitionContext222324class IOType(TypedDict):25"""A type of IO."""2627ext: str28scan: Any29sink: Any303132io_types: list[IOType] = [33{"ext": "csv", "scan": pl.scan_csv, "sink": pl.LazyFrame.sink_csv},34{"ext": "jsonl", "scan": pl.scan_ndjson, "sink": pl.LazyFrame.sink_ndjson},35{"ext": "parquet", "scan": pl.scan_parquet, "sink": pl.LazyFrame.sink_parquet},36{"ext": "ipc", "scan": pl.scan_ipc, "sink": pl.LazyFrame.sink_ipc},37]3839engines: list[EngineType] = [40"streaming",41"in-memory",42]434445@pytest.mark.parametrize("io_type", io_types)46@pytest.mark.parametrize("engine", engines)47@pytest.mark.parametrize("length", [0, 1, 4, 5, 6, 7])48@pytest.mark.parametrize("max_size", [1, 2, 3])49@pytest.mark.write_disk50def test_max_size_partition(51tmp_path: Path,52io_type: IOType,53engine: EngineType,54length: int,55max_size: int,56) -> None:57lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()5859(io_type["sink"])(60lf,61PartitionMaxSize(tmp_path, max_size=max_size),62engine=engine,63# We need to sync here because platforms do not guarantee that a close on64# one thread is immediately visible on another thread.65#66# "Multithreaded processes and close()"67# https://man7.org/linux/man-pages/man2/close.2.html68sync_on_close="data",69)7071i = 072while length > 0:73assert (io_type["scan"])(tmp_path / f"{i:08x}.{io_type['ext']}").select(74pl.len()75).collect()[0, 0] == min(max_size, length)7677length -= max_size78i += 1798081@pytest.mark.parametrize("io_type", io_types)82@pytest.mark.parametrize("engine", engines)83def test_max_size_partition_lambda(84tmp_path: Path, io_type: IOType, engine: EngineType85) -> None:86length = 1787max_size = 388lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()8990(io_type["sink"])(91lf,92PartitionMaxSize(93tmp_path,94file_path=lambda ctx: ctx.file_path.with_name("abc-" + ctx.file_path.name),95max_size=max_size,96),97engine=engine,98# We need to sync here because platforms do not guarantee that a close on99# one thread is immediately visible on another thread.100#101# "Multithreaded processes and close()"102# https://man7.org/linux/man-pages/man2/close.2.html103sync_on_close="data",104)105106i = 0107while length > 0:108assert (io_type["scan"])(tmp_path / f"abc-{i:08x}.{io_type['ext']}").select(109pl.len()110).collect()[0, 0] == min(max_size, length)111112length -= max_size113i += 1114115116@pytest.mark.parametrize("io_type", io_types)117@pytest.mark.parametrize("engine", engines)118@pytest.mark.write_disk119def test_partition_by_key(120tmp_path: Path,121io_type: IOType,122engine: EngineType,123) -> None:124lf = pl.Series("a", [i % 4 for i in range(7)], pl.Int64).to_frame().lazy()125126(io_type["sink"])(127lf,128PartitionByKey(129tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by="a"130),131engine=engine,132# We need to sync here because platforms do not guarantee that a close on133# one thread is immediately visible on another thread.134#135# "Multithreaded processes and close()"136# https://man7.org/linux/man-pages/man2/close.2.html137sync_on_close="data",138)139140assert_series_equal(141(io_type["scan"])(tmp_path / f"0.{io_type['ext']}").collect().to_series(),142pl.Series("a", [0, 0], pl.Int64),143)144assert_series_equal(145(io_type["scan"])(tmp_path / f"1.{io_type['ext']}").collect().to_series(),146pl.Series("a", [1, 1], pl.Int64),147)148assert_series_equal(149(io_type["scan"])(tmp_path / f"2.{io_type['ext']}").collect().to_series(),150pl.Series("a", [2, 2], pl.Int64),151)152assert_series_equal(153(io_type["scan"])(tmp_path / f"3.{io_type['ext']}").collect().to_series(),154pl.Series("a", [3], pl.Int64),155)156157scan_flags = (158{"schema": pl.Schema({"a": pl.String()})} if io_type["ext"] == "csv" else {}159)160161# Change the datatype.162(io_type["sink"])(163lf,164PartitionByKey(165tmp_path,166file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}",167by=pl.col.a.cast(pl.String()),168),169engine=engine,170sync_on_close="data",171)172173assert_series_equal(174(io_type["scan"])(tmp_path / f"0.{io_type['ext']}", **scan_flags)175.collect()176.to_series(),177pl.Series("a", ["0", "0"], pl.String),178)179assert_series_equal(180(io_type["scan"])(tmp_path / f"1.{io_type['ext']}", **scan_flags)181.collect()182.to_series(),183pl.Series("a", ["1", "1"], pl.String),184)185assert_series_equal(186(io_type["scan"])(tmp_path / f"2.{io_type['ext']}", **scan_flags)187.collect()188.to_series(),189pl.Series("a", ["2", "2"], pl.String),190)191assert_series_equal(192(io_type["scan"])(tmp_path / f"3.{io_type['ext']}", **scan_flags)193.collect()194.to_series(),195pl.Series("a", ["3"], pl.String),196)197198199@pytest.mark.parametrize("io_type", io_types)200@pytest.mark.parametrize("engine", engines)201@pytest.mark.write_disk202def test_partition_parted(tmp_path: Path, io_type: IOType, engine: EngineType) -> None:203s = pl.Series("a", [1, 1, 2, 3, 3, 4, 4, 4, 6], pl.Int64)204lf = s.to_frame().lazy()205206(io_type["sink"])(207lf,208PartitionParted(209tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by="a"210),211engine=engine,212# We need to sync here because platforms do not guarantee that a close on213# one thread is immediately visible on another thread.214#215# "Multithreaded processes and close()"216# https://man7.org/linux/man-pages/man2/close.2.html217sync_on_close="data",218)219220rle = s.rle()221222for i, row in enumerate(rle.struct.unnest().rows(named=True)):223assert_series_equal(224(io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").collect().to_series(),225pl.Series("a", [row["value"]] * row["len"], pl.Int64),226)227228scan_flags = (229{"schema_overrides": pl.Schema({"a_str": pl.String()})}230if io_type["ext"] == "csv"231else {}232)233234# Change the datatype.235(io_type["sink"])(236lf,237PartitionParted(238tmp_path,239file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}",240by=[pl.col.a, pl.col.a.cast(pl.String()).alias("a_str")],241),242engine=engine,243sync_on_close="data",244)245246for i, row in enumerate(rle.struct.unnest().rows(named=True)):247assert_frame_equal(248(io_type["scan"])(249tmp_path / f"{i}.{io_type['ext']}", **scan_flags250).collect(),251pl.DataFrame(252[253pl.Series("a", [row["value"]] * row["len"], pl.Int64),254pl.Series("a_str", [str(row["value"])] * row["len"], pl.String),255]256),257)258259# No include key.260(io_type["sink"])(261lf,262PartitionParted(263tmp_path,264file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}",265by=[pl.col.a.cast(pl.String()).alias("a_str")],266include_key=False,267),268engine=engine,269sync_on_close="data",270)271272for i, row in enumerate(rle.struct.unnest().rows(named=True)):273assert_series_equal(274(io_type["scan"])(tmp_path / f"{i}.{io_type['ext']}").collect().to_series(),275pl.Series("a", [row["value"]] * row["len"], pl.Int64),276)277278279# We only deal with self-describing formats280@pytest.mark.parametrize("io_type", [io_types[2], io_types[3]])281@pytest.mark.parametrize("engine", engines)282@pytest.mark.write_disk283@given(284df=dataframes(285min_cols=1,286excluded_dtypes=[287pl.Decimal, # Bug see: https://github.com/pola-rs/polars/issues/21684288pl.Duration, # Bug see: https://github.com/pola-rs/polars/issues/21964289pl.Categorical, # We cannot ensure the string cache is properly held.290# Generate invalid UTF-8291pl.Binary,292pl.Struct,293pl.Array,294pl.List,295],296)297)298def test_partition_by_key_parametric(299tmp_path_factory: pytest.TempPathFactory,300io_type: IOType,301engine: EngineType,302df: pl.DataFrame,303) -> None:304col1 = df.columns[0]305306tmp_path = tmp_path_factory.mktemp("data")307308dfs = df.partition_by(col1)309(io_type["sink"])(310df.lazy(),311PartitionByKey(312tmp_path, file_path=lambda ctx: f"{ctx.file_idx}.{io_type['ext']}", by=col1313),314engine=engine,315# We need to sync here because platforms do not guarantee that a close on316# one thread is immediately visible on another thread.317#318# "Multithreaded processes and close()"319# https://man7.org/linux/man-pages/man2/close.2.html320sync_on_close="data",321)322323for i, df in enumerate(dfs):324assert_frame_equal(325df,326(io_type["scan"])(327tmp_path / f"{i}.{io_type['ext']}",328).collect(),329)330331332def test_max_size_partition_collect_files(tmp_path: Path) -> None:333length = 17334max_size = 3335lf = pl.Series("a", range(length), pl.Int64).to_frame().lazy()336337io_type = io_types[0]338output_files = []339340def file_path_cb(ctx: BasePartitionContext) -> Path:341print(ctx)342print(ctx.full_path)343output_files.append(ctx.full_path)344print(ctx.file_path)345return ctx.file_path346347(io_type["sink"])(348lf,349PartitionMaxSize(tmp_path, file_path=file_path_cb, max_size=max_size),350engine="streaming",351# We need to sync here because platforms do not guarantee that a close on352# one thread is immediately visible on another thread.353#354# "Multithreaded processes and close()"355# https://man7.org/linux/man-pages/man2/close.2.html356sync_on_close="data",357)358359assert output_files == [tmp_path / f"{i:08x}.{io_type['ext']}" for i in range(6)]360361362@pytest.mark.parametrize(("io_type"), io_types)363@pytest.mark.parametrize("engine", engines)364def test_partition_to_memory(io_type: IOType, engine: EngineType) -> None:365df = pl.DataFrame(366{367"a": [5, 10, 1996],368}369)370371output_files = {}372373def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO:374f = io.BytesIO()375output_files[ctx.file_path] = f376return f377378io_type["sink"](379df.lazy(),380PartitionMaxSize("", file_path=file_path_cb, max_size=1),381engine=engine,382)383384assert len(output_files) == df.height385for i, (_, value) in enumerate(output_files.items()):386value.seek(0)387assert_frame_equal(io_type["scan"](value).collect(), df.slice(i, 1))388389390def test_partition_key_order_22645() -> None:391paths = []392393def cb(ctx: KeyedPartitionContext) -> io.BytesIO:394paths.append(ctx.file_path.parent)395return io.BytesIO() # return an dummy output396397pl.LazyFrame({"a": [1, 2, 3]}).sink_parquet(398pl.io.PartitionByKey(399"",400file_path=cb,401by=[pl.col.a.alias("b"), (pl.col.a + 42).alias("c")],402),403)404405paths.sort()406assert [p.parts for p in paths] == [407("b=1", "c=43"),408("b=2", "c=44"),409("b=3", "c=45"),410]411412413@pytest.mark.parametrize(("io_type"), io_types)414@pytest.mark.parametrize("engine", engines)415@pytest.mark.parametrize(416("df", "sorts"),417[418(pl.DataFrame({"a": [2, 1, 0, 4, 3, 5, 7, 8, 9]}), "a"),419(420pl.DataFrame(421{"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}422),423"a",424),425(426pl.DataFrame(427{"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}428),429["a", "b"],430),431(432pl.DataFrame(433{"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}434),435"b",436),437(438pl.DataFrame(439{"a": [2, 1, 0, 4, 3, 5, 7, 8, 9], "b": [f"s{i}" for i in range(9)]}440),441pl.col.a - pl.col.b.str.slice(1).cast(pl.Int64),442),443],444)445def test_partition_to_memory_sort_by(446io_type: IOType,447engine: EngineType,448df: pl.DataFrame,449sorts: str | pl.Expr | list[str | pl.Expr],450) -> None:451output_files = {}452453def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO:454f = io.BytesIO()455output_files[ctx.file_path] = f456return f457458io_type["sink"](459df.lazy(),460PartitionMaxSize(461"", file_path=file_path_cb, max_size=3, per_partition_sort_by=sorts462),463engine=engine,464)465466assert len(output_files) == df.height / 3467for i, (_, value) in enumerate(output_files.items()):468value.seek(0)469assert_frame_equal(470io_type["scan"](value).collect(), df.slice(i * 3, 3).sort(sorts)471)472473474@pytest.mark.parametrize(("io_type"), io_types)475@pytest.mark.parametrize("engine", engines)476def test_partition_to_memory_finish_callback(477io_type: IOType, engine: EngineType478) -> None:479df = pl.DataFrame(480{481"a": [5, 10, 1996],482}483)484485output_files = {}486487def file_path_cb(ctx: BasePartitionContext) -> io.BytesIO:488f = io.BytesIO()489output_files[ctx.file_path] = f490return f491492num_calls = 0493494def finish_callback(df: pl.DataFrame) -> None:495nonlocal num_calls496num_calls += 1497498if io_type["ext"] == "parquet":499assert df.height == 3500501io_type["sink"](502df.lazy(),503PartitionMaxSize(504"", file_path=file_path_cb, max_size=1, finish_callback=finish_callback505),506engine=engine,507)508assert num_calls == 1509510with pytest.raises(FileNotFoundError):511io_type["sink"](512df.lazy(),513PartitionMaxSize(514"/path/to/non-existent-paths",515max_size=1,516finish_callback=finish_callback,517),518)519assert num_calls == 1 # Should not get called here520521522def test_finish_callback_nested_23306() -> None:523data = [{"a": "foo", "b": "bar", "c": ["hello", "ciao", "hola", "bonjour"]}]524525lf = pl.LazyFrame(data)526527def finish_callback(df: None | pl.DataFrame = None) -> None:528assert df is not None529assert df.height == 1530531partitioning = pl.PartitionByKey(532"/",533file_path=lambda _: io.BytesIO(),534by=["a", "b"],535finish_callback=finish_callback,536)537538lf.sink_parquet(partitioning, mkdir=True)539540541@pytest.mark.write_disk542def test_parquet_preserve_order_within_partition_23376(tmp_path: Path) -> None:543ll = list(range(20))544df = pl.DataFrame({"a": ll})545df.lazy().sink_parquet(pl.PartitionMaxSize(tmp_path, max_size=1))546out = pl.scan_parquet(tmp_path).collect().to_series().to_list()547assert ll == out548549550