Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_lazy_count_star.py
8448 views
1
from __future__ import annotations
2
3
import subprocess
4
import sys
5
from typing import TYPE_CHECKING
6
7
if TYPE_CHECKING:
8
from pathlib import Path
9
10
from polars.lazyframe.frame import LazyFrame
11
from tests.conftest import PlMonkeyPatch
12
13
import gzip
14
import re
15
from tempfile import NamedTemporaryFile
16
17
import pytest
18
19
import polars as pl
20
from polars.testing import assert_frame_equal
21
22
23
def _verbose_collect(
    lf: LazyFrame,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> tuple[pl.DataFrame, set[str]]:
    """Collect `lf` with POLARS_VERBOSE=1 and harvest the projection logs.

    Returns the collected frame together with the set of distinct
    `project: <n>` messages found on stderr during the collect.
    """
    capfd.readouterr()  # drop anything already captured on stderr
    with plmonkeypatch.context() as ctx:
        ctx.setenv("POLARS_VERBOSE", "1")
        out = lf.collect()
    stderr_text = capfd.readouterr().err
    return out, set(re.findall(r"project: \d+", stderr_text))


def assert_fast_count(
    lf: LazyFrame,
    expected_count: int,
    *,
    expected_name: str = "len",
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Assert that the COUNT(*) query `lf` returns `expected_count`.

    The query is collected three times:
    1. with default settings (fast count allowed),
    2. with POLARS_NO_FAST_FILE_COUNT=1 forcing the normal scan logic,
    3. (CSV only) again with fast count re-enabled,
    checking the verbose `project:` logs each time.
    """
    result, projection_logs = _verbose_collect(lf, capfd, plmonkeypatch)

    # Logs currently differ depending on file type / implementation dispatch.
    if "FAST COUNT" in lf.explain():
        # No projections at all when the fast-count path is taken.
        assert not projection_logs
    else:
        # Otherwise at least one `project: 0` (there is 1 per file).
        assert projection_logs == {"project: 0"}

    assert result.schema == {expected_name: pl.get_index_type()}
    assert result.item() == expected_count

    # Disable the fast-count optimization to check that the normal scan
    # logic counts as expected.
    plmonkeypatch.setenv("POLARS_NO_FAST_FILE_COUNT", "1")

    result, projection_logs = _verbose_collect(lf, capfd, plmonkeypatch)
    assert result.item() == expected_count

    assert "FAST COUNT" not in lf.explain()
    assert projection_logs == {"project: 0"}

    # Restore the default behavior that allows the fast-count optimization.
    plmonkeypatch.setenv("POLARS_NO_FAST_FILE_COUNT", "0")

    plan = lf.explain()
    if "Csv" not in plan:
        assert "FAST COUNT" not in plan
        return

    # CSV is the only format that uses a custom fast-count kernel, so we want
    # to make sure that the normal scan logic has the same count behavior.
    assert "FAST COUNT" in plan

    result, projection_logs = _verbose_collect(lf, capfd, plmonkeypatch)
    assert result.item() == expected_count
    assert not projection_logs
@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.csv", 27), ("foods*.csv", 27 * 5)]
)
def test_count_csv(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """COUNT(*) over a single CSV file and over a glob of CSV files."""
    query = pl.scan_csv(io_files_path / path).select(pl.len())
    assert_fast_count(query, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)
def test_count_csv_comment_char(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    """COUNT(*) must honor `comment_prefix` (and blank lines become null rows)."""
    csv_bytes = b"""
a,b
1,2

#
3,4
"""
    q = pl.scan_csv(csv_bytes, comment_prefix="#")

    expected = pl.DataFrame({"a": [1, None, 3], "b": [2, None, 4]})
    assert_frame_equal(q.collect(), expected)

    q = q.select(pl.len())
    assert_fast_count(q, 3, capfd=capfd, plmonkeypatch=plmonkeypatch)
def test_count_csv_no_newline_on_last_22564() -> None:
    """The final row must be counted even without a trailing newline (#22564)."""
    data = b"""\
a,b
1,2
3,4
5,6"""

    # `comment_prefix=None` is the default, so both variants hit the same
    # payload; the commented variant exercises the comment-aware count path.
    for comment_prefix in (None, "#"):
        scan = pl.scan_csv(data, comment_prefix=comment_prefix)
        assert scan.collect().height == 3
        assert scan.select(pl.len()).collect().item() == 3
@pytest.mark.write_disk
def test_commented_csv(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    """COUNT(*) with `comment_prefix` on a file on disk and on in-memory bytes."""
    with NamedTemporaryFile() as csv_a:
        csv_a.write(b"A,B\nGr1,A\nGr1,B\n# comment line\n")
        csv_a.seek(0)

        lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len())
        assert_fast_count(lf, 2, capfd=capfd, plmonkeypatch=plmonkeypatch)

    # Headerless in-memory payloads: (csv bytes, expected row count).
    # Covers: no newline at EOF, interleaved comment lines, and a trailing
    # comment with/without a final newline.
    cases = [
        (b"AAA", 1),
        (b"AAA\nBBB", 2),
        (b"AAA\n#comment\nBBB\n#comment", 2),
        (b"AAA\n#comment\nBBB\n#comment\nCCC\n#comment", 3),
        (b"AAA\n#comment\nBBB\n#comment\nCCC\n#comment\n", 3),
    ]

    for payload, expected_rows in cases:
        lf = pl.scan_csv(
            payload,
            has_header=False,
            comment_prefix="#",
        ).select(pl.len())
        assert_fast_count(lf, expected_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)
@pytest.mark.parametrize(
    ("pattern", "n_rows"), [("small.parquet", 4), ("foods*.parquet", 54)]
)
def test_count_parquet(
    io_files_path: Path,
    pattern: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """COUNT(*) over a single Parquet file and over a glob of Parquet files."""
    query = pl.scan_parquet(io_files_path / pattern).select(pl.len())
    assert_fast_count(query, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)
@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.ipc", 27), ("foods*.ipc", 27 * 2)]
)
def test_count_ipc(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """COUNT(*) over a single IPC file and over a glob of IPC files."""
    query = pl.scan_ipc(io_files_path / path).select(pl.len())
    assert_fast_count(query, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)
@pytest.mark.parametrize(
    ("path", "n_rows"), [("foods1.ndjson", 27), ("foods*.ndjson", 27 * 2)]
)
def test_count_ndjson(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """COUNT(*) over a single NDJSON file and over a glob of NDJSON files."""
    query = pl.scan_ndjson(io_files_path / path).select(pl.len())
    assert_fast_count(query, n_rows, capfd=capfd, plmonkeypatch=plmonkeypatch)
def test_count_compressed_csv_18057(
    io_files_path: Path,
    capfd: pytest.CaptureFixture[str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """COUNT(*) on a gzip-compressed CSV (#18057)."""
    compressed = io_files_path / "gzipped.csv.gz"

    lf = pl.scan_csv(compressed, truncate_ragged_lines=True)
    expected = pl.DataFrame(
        {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]}
    )
    assert_frame_equal(lf.collect(), expected)

    # This also tests:
    # #18070 "CSV count_rows does not skip empty lines at file start"
    # as the file has an empty line at the beginning.
    q = lf.select(pl.len())
    assert_fast_count(q, 3, capfd=capfd, plmonkeypatch=plmonkeypatch)
@pytest.mark.write_disk
def test_count_compressed_ndjson(
    tmp_path: Path, capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    """COUNT(*) on a gzip-compressed NDJSON file."""
    tmp_path.mkdir(exist_ok=True)
    target = tmp_path / "data.jsonl.gz"

    frame = pl.DataFrame({"x": range(5)})
    with gzip.open(target, "wb") as gz:
        frame.write_ndjson(gz)  # type: ignore[call-overload]

    query = pl.scan_ndjson(target).select(pl.len())
    assert_fast_count(query, 5, capfd=capfd, plmonkeypatch=plmonkeypatch)
def test_count_projection_pd(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    """COUNT(*) after `with_row_index` still pushes down a 0-width projection."""
    source = pl.DataFrame({"a": range(3), "b": range(3)})

    q = (
        pl.scan_csv(source.write_csv().encode())
        .with_row_index()
        .select(pl.all())
        .select(pl.len())
    )

    # Manual assert: this is not converted to FAST COUNT, but we will have
    # 0-width projections.
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")
    capfd.readouterr()
    result = q.collect()
    stderr_text = capfd.readouterr().err
    projection_logs = set(re.findall(r"project: \d+", stderr_text))

    assert projection_logs == {"project: 0"}
    assert result.item() == 3
def test_csv_scan_skip_lines_len_22889(
    capfd: pytest.CaptureFixture[str], plmonkeypatch: PlMonkeyPatch
) -> None:
    """COUNT(*) must honor `skip_lines` (#22889)."""
    payload = b"col\n1\n2\n3"
    query = pl.scan_csv(payload, skip_lines=2).select(pl.len())
    assert_fast_count(query, 1, capfd=capfd, plmonkeypatch=plmonkeypatch)

    # trigger multi-threading code path
    big_payload = b"1\n2\n3\n4\n5\n6\n7\n8\n9\n0\n" * 1000
    query = pl.scan_csv(big_payload, skip_lines=1000, has_header=False).select(pl.len())
    assert_fast_count(query, 9000, capfd=capfd, plmonkeypatch=plmonkeypatch)

    # for comparison
    out = pl.scan_csv(payload, skip_lines=2).collect().select(pl.len())
    expected = pl.DataFrame({"len": [1]}, schema={"len": pl.get_index_type()})
    assert_frame_equal(expected, out)
@pytest.mark.write_disk
@pytest.mark.slow
@pytest.mark.parametrize(
    "exec_str",
    [
        "pl.LazyFrame(height=n_rows).select(pl.len()).collect().item()",
        "pl.scan_parquet(parquet_file_path).select(pl.len()).collect().item()",
        "pl.scan_ipc(ipc_file_path).select(pl.len()).collect().item()",
        'pl.LazyFrame({"a": s, "b": s, "c": s}).select("c", "b").collect().height',
        """\
pl.collect_all(
    [
        pl.scan_parquet(parquet_file_path).select(pl.len()),
        pl.scan_ipc(ipc_file_path).select(pl.len()),
        pl.LazyFrame(height=n_rows).select(pl.len()),
    ]
)[0].item()""",
    ],
)
def test_streaming_fast_count_disables_morsel_split(
    tmp_path: Path, exec_str: str
) -> None:
    """Check that fast-count queries do not get morsel-split in streaming mode.

    Each parametrized `exec_str` is a COUNT-style query evaluated (via `eval`)
    inside a subprocess where the ideal morsel size is forced to 1. If the
    engine split the result into morsels at that size, the huge row count
    below would make the subprocess blow past its 5-second timeout.
    """
    # NOTE(review): presumably chosen to sit just under the 32-bit index
    # limit while remaining representable — confirm against IdxSize.
    n_rows = (1 << 32) - 2
    parquet_file_path = tmp_path / "data.parquet"
    ipc_file_path = tmp_path / "data.ipc"

    # argv passed to both subprocess scripts; the scripts unpack these
    # positionally (the first script ignores `exec_str`).
    script_args = [str(n_rows), str(parquet_file_path), str(ipc_file_path), exec_str]

    # We spawn 2 processes - the first process sets a huge ideal morsel size to
    # generate the data quickly. The 2nd process sets the ideal morsel size to 1,
    # making it so that if morsel splitting is performed it would exceed the
    # timeout of 5 seconds.

    # Process 1: write the zero-column test data to Parquet and IPC. The
    # single huge row group / record batch keeps generation fast.
    assert (
        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import os
import sys

os.environ["POLARS_IDEAL_MORSEL_SIZE"] = str(1_000_000_000)

import polars as pl

pl.Config.set_engine_affinity("streaming")

(
    _,
    n_rows,
    parquet_file_path,
    ipc_file_path,
    _,
) = sys.argv

n_rows = int(n_rows)

pl.LazyFrame(height=n_rows).sink_parquet(parquet_file_path, row_group_size=1_000_000_000)
pl.LazyFrame(height=n_rows).sink_ipc(ipc_file_path, record_batch_size=1_000_000_000)

print("OK", end="")
""",
                *script_args,
            ],
            timeout=5,
        )
        == b"OK"
    )

    # Process 2: run the actual COUNT query with morsel size 1. `s` is a
    # zero-field struct column repeated n_rows times, used by the in-memory
    # `exec_str` variant; the query must finish within the timeout.
    assert (
        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import os
import sys

os.environ["POLARS_IDEAL_MORSEL_SIZE"] = "1"

import polars as pl

pl.Config.set_engine_affinity("streaming")

(
    _,
    n_rows,
    parquet_file_path,
    ipc_file_path,
    exec_str,
) = sys.argv

n_rows = int(n_rows)

s = pl.Series([{}], dtype=pl.Struct({})).new_from_index(0, n_rows)
assert eval(exec_str) == n_rows

print("OK", end="")
""",
                *script_args,
            ],
            timeout=5,
        )
        == b"OK"
    )