GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_sink.py
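"""Tests for sinking LazyFrames to files, open file objects, and in-memory buffers."""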
from __future__ import annotations

import io
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
    from polars._typing import EngineType
    from tests.conftest import PlMonkeyPatch

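# One (scan, sink) pair per file format under test.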
SINKS = [
    (pl.scan_ipc, pl.LazyFrame.sink_ipc),
    (pl.scan_parquet, pl.LazyFrame.sink_parquet),
    (pl.scan_csv, pl.LazyFrame.sink_csv),
    (pl.scan_ndjson, pl.LazyFrame.sink_ndjson),
]

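# Sinking into a missing directory raises unless mkdir=True creates the parents.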
@pytest.mark.parametrize(("scan", "sink"), SINKS)
@pytest.mark.parametrize("engine", ["in-memory", "streaming"])
@pytest.mark.write_disk
def test_mkdir(tmp_path: Path, scan: Any, sink: Any, engine: EngineType) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3],
        }
    )

    with pytest.raises(FileNotFoundError):
        sink(df.lazy(), tmp_path / "a" / "b" / "c" / "file", engine=engine)

    f = tmp_path / "a" / "b" / "c" / "file2"
    sink(df.lazy(), f, mkdir=True)

    assert_frame_equal(scan(f).collect(), df)

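# The eager DataFrame.write_parquet path honours mkdir=True in the same way.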
def test_write_mkdir(tmp_path: Path) -> None:
    df = pl.DataFrame(
        {
            "a": [1, 2, 3],
        }
    )

    with pytest.raises(FileNotFoundError):
        df.write_parquet(tmp_path / "a" / "b" / "c" / "file")

    f = tmp_path / "a" / "b" / "c" / "file2"
    df.write_parquet(f, mkdir=True)

    assert_frame_equal(pl.read_parquet(f), df)

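# lazy=True returns the sink as a LazyFrame; nothing is written until it is collected.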
@pytest.mark.parametrize(("scan", "sink"), SINKS)
@pytest.mark.parametrize("engine", ["in-memory", "streaming"])
@pytest.mark.write_disk
def test_lazy_sinks(tmp_path: Path, scan: Any, sink: Any, engine: EngineType) -> None:
    df = pl.DataFrame({"a": [1, 2, 3]})
    lf1 = sink(df.lazy(), tmp_path / "a", lazy=True)
    lf2 = sink(df.lazy(), tmp_path / "b", lazy=True)

    assert not Path(tmp_path / "a").exists()
    assert not Path(tmp_path / "b").exists()

    pl.collect_all([lf1, lf2], engine=engine)

    assert_frame_equal(scan(tmp_path / "a").collect(), df)
    assert_frame_equal(scan(tmp_path / "b").collect(), df)

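# Chaining a second sink onto an already-sinking plan must raise InvalidOperationError.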
@pytest.mark.parametrize(
    "sink",
    [
        pl.LazyFrame.sink_ipc,
        pl.LazyFrame.sink_parquet,
        pl.LazyFrame.sink_csv,
        pl.LazyFrame.sink_ndjson,
    ],
)
@pytest.mark.write_disk
def test_double_lazy_error(sink: Any) -> None:
    df = pl.DataFrame({})

    with pytest.raises(
        pl.exceptions.InvalidOperationError,
        match="cannot create a sink on top of another sink",
    ):
        sink(sink(df.lazy(), "a", lazy=True), "b")

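# Sinks accept in-memory buffers such as io.BytesIO, not just paths.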
@pytest.mark.parametrize(("scan", "sink"), SINKS)
def test_sink_to_memory(sink: Any, scan: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [5, 10, 1996],
        }
    )

    f = io.BytesIO()
    sink(df.lazy(), f)

    f.seek(0)
    assert_frame_equal(
        scan(f).collect(),
        df,
    )

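# Sinks can also write into an already-open file handle.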
@pytest.mark.parametrize(("scan", "sink"), SINKS)
@pytest.mark.write_disk
def test_sink_to_file(tmp_path: Path, sink: Any, scan: Any) -> None:
    df = pl.DataFrame(
        {
            "a": [5, 10, 1996],
        }
    )

    with (tmp_path / "f").open("w+") as f:
        sink(df.lazy(), f, sync_on_close="all")
        f.seek(0)
        assert_frame_equal(
            scan(f).collect(),
            df,
        )

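# Cross joins with an empty side must still sink a valid, schema-correct file.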
@pytest.mark.parametrize(("scan", "sink"), SINKS)
def test_sink_empty(sink: Any, scan: Any) -> None:
    df = pl.LazyFrame(data={"col1": ["a"]})

    df_empty = pl.LazyFrame(
        data={"col1": []},
        schema={"col1": str},
    )

    expected = df_empty.join(df, how="cross").collect()
    expected_schema = expected.schema

    kwargs = {}
    if scan == pl.scan_ndjson:
        kwargs["schema"] = expected_schema

    # right empty
    f = io.BytesIO()
    sink(df.join(df_empty, how="cross"), f)
    f.seek(0)
    assert_frame_equal(scan(f, **kwargs), expected.lazy())

    # left empty
    f.seek(0)
    sink(df_empty.join(df, how="cross"), f)
    f.truncate()
    f.seek(0)
    assert_frame_equal(scan(f, **kwargs), expected.lazy())

    # both empty
    f.seek(0)
    sink(df_empty.join(df_empty, how="cross"), f)
    f.truncate()
    f.seek(0)
    assert_frame_equal(scan(f, **kwargs), expected.lazy())

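# Regression test for issue 25806: sinking several morsels of a Boolean column panicked.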
@pytest.mark.parametrize(("scan", "sink"), SINKS)
def test_sink_boolean_panic_25806(sink: Any, scan: Any) -> None:
    morsel_size = int(os.environ.get("POLARS_IDEAL_MORSEL_SIZE", 100_000))
    df = pl.select(bool=pl.repeat(True, 3 * morsel_size))

    f = io.BytesIO()
    sink(df.lazy(), f)

    assert_frame_equal(scan(f).collect(), df)

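# collect_all(lazy=True) should combine several sinks into a single SinkMultiple plan.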
def test_collect_all_lazy() -> None:
    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)

        a = pl.LazyFrame({"a": [1, 2, 3, 4, 5, 6]})
        b = a.filter(pl.col("a") % 2 == 0).sink_csv(tmp_path / "b.csv", lazy=True)
        c = a.filter(pl.col("a") % 3 == 0).sink_csv(tmp_path / "c.csv", lazy=True)
        d = a.sink_csv(tmp_path / "a.csv", lazy=True)

        q = pl.collect_all([d, b, c], lazy=True)

        assert q._ldf._node_name() == "SinkMultiple"  # type: ignore[attr-defined]
        q.collect()
        df_a = pl.read_csv(tmp_path / "a.csv")
        df_b = pl.read_csv(tmp_path / "b.csv")
        df_c = pl.read_csv(tmp_path / "c.csv")

        assert_frame_equal(df_a, pl.DataFrame({"a": [1, 2, 3, 4, 5, 6]}))
        assert_frame_equal(df_b, pl.DataFrame({"a": [2, 4, 6]}))
        assert_frame_equal(df_c, pl.DataFrame({"a": [3, 6]}))

    with pytest.raises(ValueError, match="all LazyFrames must end with a sink to use"):
        pl.collect_all([a, a], lazy=True)

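# Helpers shared by the compression tests below.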
def check_compression(content: bytes, expected_format: str) -> None:
    # Check the stream's magic bytes: gzip starts with 1F 8B, zstd with 28 B5 2F FD.
    if expected_format == "gzip":
        assert content[:2] == bytes([0x1F, 0x8B])
    elif expected_format == "zstd":
        assert content[:4] == bytes([0x28, 0xB5, 0x2F, 0xFD])
    else:
        pytest.fail("unreachable")


def write_fn(df: pl.DataFrame, write_fn_name: str) -> Any:
    if write_fn_name == "write_csv":
        return df.write_csv
    elif write_fn_name == "sink_csv":
        return df.lazy().sink_csv
    elif write_fn_name == "write_ndjson":
        return df.write_ndjson
    elif write_fn_name == "sink_ndjson":
        return df.lazy().sink_ndjson
    else:
        pytest.fail("unreachable")


def scan_fn(write_fn_name: str) -> Any:
    if "csv" in write_fn_name:
        return pl.scan_csv
    elif "ndjson" in write_fn_name:
        return pl.scan_ndjson
    else:
        pytest.fail("unreachable")

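# Compressed CSV/NDJSON output round-trips through an in-memory buffer.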
@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize("fmt", ["gzip", "zstd"])
@pytest.mark.parametrize("level", [None, 0, 9])
def test_write_compressed(write_fn_name: str, fmt: str, level: int | None) -> None:
    original = pl.DataFrame([pl.Series("A", [3.2, 6.2]), pl.Series("B", ["a", "z"])])
    buf = io.BytesIO()
    write_fn(original, write_fn_name)(buf, compression=fmt, compression_level=level)
    buf.seek(0)
    check_compression(buf.read(), fmt)
    buf.seek(0)
    df = scan_fn(write_fn_name)(buf).collect()
    assert_frame_equal(df, original)

@pytest.mark.write_disk
@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize(("fmt", "suffix"), [("gzip", ".gz"), ("zstd", ".zst")])
@pytest.mark.parametrize("with_suffix", [True, False])
def test_write_compressed_disk(
    tmp_path: Path, write_fn_name: str, fmt: str, suffix: str, with_suffix: bool
) -> None:
    original = pl.DataFrame([pl.Series("A", [3.2, 6.2]), pl.Series("B", ["a", "z"])])
    path = tmp_path / (f"test_file{suffix}" if with_suffix else "test_file")
    write_fn(original, write_fn_name)(path, compression=fmt)
    with path.open("rb") as file:
        check_compression(file.read(), fmt)
    df = scan_fn(write_fn_name)(path).collect()
    assert_frame_equal(df, original)

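# check_extension=False permits compressed output under a filename without the matching suffix.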
@pytest.mark.write_disk
@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize("fmt", ["gzip", "zstd"])
def test_write_uncommon_file_suffix_ignore(
    tmp_path: Path, write_fn_name: str, fmt: str
) -> None:
    path = tmp_path / "x"
    write_fn(pl.DataFrame(), write_fn_name)(
        path, compression=fmt, check_extension=False
    )
    with Path.open(path, "rb") as file:
        check_compression(file.read(), fmt)

@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize("fmt", ["gzip", "zstd"])
def test_write_uncommon_file_suffix_raise(write_fn_name: str, fmt: str) -> None:
    with pytest.raises(pl.exceptions.InvalidOperationError):
        write_fn(pl.DataFrame(), write_fn_name)("x.csv", compression=fmt)

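# Writing to a *.gz / *.zst / *.zstd path without the compression parameter raises a helpful error.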
@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize("extension", ["gz", "zst", "zstd"])
def test_write_intended_compression(write_fn_name: str, extension: str) -> None:
    with pytest.raises(
        pl.exceptions.InvalidOperationError, match="use the compression parameter"
    ):
        write_fn(pl.DataFrame(), write_fn_name)(f"x.csv.{extension}")

@pytest.mark.write_disk
@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize("extension", ["tsv", "xslb", "cs"])
def test_write_alternative_extension(
    tmp_path: Path, write_fn_name: str, extension: str
) -> None:
    path = tmp_path / f"x.{extension}"
    write_fn(pl.DataFrame(), write_fn_name)(path)
    assert Path.exists(path)

@pytest.mark.parametrize(
    "write_fn_name", ["write_csv", "sink_csv", "write_ndjson", "sink_ndjson"]
)
@pytest.mark.parametrize("fmt", ["gzipd", "zs", ""])
def test_write_unsupported_compression(write_fn_name: str, fmt: str) -> None:
    with pytest.raises(pl.exceptions.InvalidOperationError):
        write_fn(pl.DataFrame(), write_fn_name)("x", compression=fmt)

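# Regression test for issue 26324: slicing sink paths must respect UTF-8 character boundaries.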
@pytest.mark.write_disk
@pytest.mark.parametrize("file_name", ["凸变英雄X", "影分身の術"])
def test_sink_path_slicing_utf8_boundaries_26324(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path, file_name: str
) -> None:
    monkeypatch.chdir(tmp_path)

    df = pl.DataFrame({"a": 1})
    df.write_parquet(file_name)

    assert_frame_equal(pl.scan_parquet(file_name).collect(), df)

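# With POLARS_LOG_METRICS set, a sink emits an io-sink metrics line; total_bytes_sent should equal the on-disk file size.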
@pytest.mark.parametrize("file_format", ["parquet", "ipc", "csv", "ndjson"])
@pytest.mark.parametrize("partitioned", [True, False])
@pytest.mark.write_disk
def test_sink_metrics(
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
    file_format: str,
    tmp_path: Path,
    partitioned: bool,
) -> None:
    path = tmp_path / "a"

    df = pl.DataFrame({"a": 1})

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_LOG_METRICS", "1")
        cx.setenv("POLARS_FORCE_ASYNC", "1")
        capfd.readouterr()  # discard anything captured before the sink runs
        getattr(pl.LazyFrame, f"sink_{file_format}")(
            df.lazy(),
            path
            if not partitioned
            else pl.PartitionBy("", file_path_provider=(lambda _: path), key="a"),
        )
        capture = capfd.readouterr().err

    [line] = (x for x in capture.splitlines() if x.startswith("io-sink"))

    logged_bytes_sent = int(
        pl.select(pl.lit(line).str.extract(r"total_bytes_sent=(\d+)")).item()
    )

    assert logged_bytes_sent == path.stat().st_size

    assert_frame_equal(getattr(pl, f"scan_{file_format}")(path).collect(), df)