Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/test_lazy_count_star.py
6939 views
1
from __future__ import annotations
2
3
from typing import TYPE_CHECKING
4
5
if TYPE_CHECKING:
6
from pathlib import Path
7
8
from polars.lazyframe.frame import LazyFrame
9
10
import gzip
11
import re
12
from tempfile import NamedTemporaryFile
13
14
import pytest
15
16
import polars as pl
17
from polars.testing import assert_frame_equal
18
19
20
_PROJECT_LOG_PATTERN = re.compile(r"project: \d+")


def _collect_with_project_logs(
    lf: LazyFrame, capfd: pytest.CaptureFixture[str]
) -> tuple[pl.DataFrame, set[str]]:
    """Collect ``lf`` and return the result plus the distinct ``project: N`` logs.

    Anything captured on stderr before the call is discarded first, so the
    returned log set covers only this single ``collect()``. The projection logs
    presumably only appear with ``POLARS_VERBOSE=1``, which the caller sets.
    """
    capfd.readouterr()  # discard previously captured output
    result = lf.collect()
    captured_err = capfd.readouterr().err
    return result, set(_PROJECT_LOG_PATTERN.findall(captured_err))


def assert_fast_count(
    lf: LazyFrame,
    expected_count: int,
    *,
    expected_name: str = "len",
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Assert that a ``COUNT(*)`` query is correct on every dispatch path.

    The query is collected three times:

    1. with the caller's (default) fast-count dispatch setting;
    2. with ``POLARS_FAST_FILE_COUNT_DISPATCH=0`` -- the plan must not contain
       ``FAST COUNT`` and the scans must log ``project: 0``;
    3. with ``POLARS_FAST_FILE_COUNT_DISPATCH=1`` -- the plan must contain
       ``FAST COUNT`` and no projections may be logged.

    Every run must return a single ``expected_name`` column of the index type
    holding ``expected_count``.

    Parameters
    ----------
    lf
        A lazy ``COUNT(*)`` query, e.g. ``pl.scan_csv(...).select(pl.len())``.
    expected_count
        The row count the query must return.
    expected_name
        Expected name of the single output column.
    capfd
        pytest fixture used to capture the verbose stderr logs.
    monkeypatch
        pytest fixture used to set environment variables. All changes are made
        in a scoped ``monkeypatch.context()`` and undone before returning, so
        repeated calls within one test each start from the caller's environment.
    """
    expected_schema = {expected_name: pl.get_index_type()}

    with monkeypatch.context() as mp:
        mp.setenv("POLARS_VERBOSE", "1")

        # 1. Default dispatch. Logs currently differ depending on file type /
        #    implementation dispatch.
        result, project_logs = _collect_with_project_logs(lf, capfd)
        if "FAST COUNT" in lf.explain():
            # Fast count should need no projections at all.
            assert not project_logs
        else:
            # Otherwise there is one `project: 0` per scanned file; the set
            # collapses them into a single entry.
            assert project_logs == {"project: 0"}
        assert result.schema == expected_schema
        assert result.item() == expected_count

        # 2. Fast count forcibly disabled: the count must still be correct.
        mp.setenv("POLARS_FAST_FILE_COUNT_DISPATCH", "0")
        result, project_logs = _collect_with_project_logs(lf, capfd)
        assert "FAST COUNT" not in lf.explain()
        assert project_logs == {"project: 0"}
        assert result.schema == expected_schema
        assert result.item() == expected_count

        # 3. Fast count forcibly enabled.
        mp.setenv("POLARS_FAST_FILE_COUNT_DISPATCH", "1")
        result, project_logs = _collect_with_project_logs(lf, capfd)
        assert "FAST COUNT" in lf.explain()
        assert not project_logs
        assert result.schema == expected_schema
        assert result.item() == expected_count
69
70
@pytest.mark.parametrize(
    ("path", "n_rows"),
    [
        ("foods1.csv", 27),
        ("foods*.csv", 27 * 5),
    ],
)
def test_count_csv(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``COUNT(*)`` over a single CSV file and over a glob of CSV files."""
    source = io_files_path / path
    count_query = pl.scan_csv(source).select(pl.len())

    assert_fast_count(count_query, n_rows, capfd=capfd, monkeypatch=monkeypatch)
83
84
85
def test_count_csv_comment_char(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Fast count with ``comment_prefix`` set.

    The input starts with a newline and contains an empty line plus a comment
    line. The full scan yields three rows, the middle one all-null, and the
    count must agree.
    """
    csv_data = b"\na,b\n1,2\n\n#\n3,4\n"
    lf = pl.scan_csv(csv_data, comment_prefix="#")

    expected = pl.DataFrame({"a": [1, None, 3], "b": [2, None, 4]})
    assert_frame_equal(lf.collect(), expected)

    assert_fast_count(lf.select(pl.len()), 3, capfd=capfd, monkeypatch=monkeypatch)
105
106
107
def test_count_csv_no_newline_on_last_22564() -> None:
    """A final line without a trailing newline must still be counted (#22564)."""
    data = b"a,b\n1,2\n3,4\n5,6"

    plain = pl.scan_csv(data)
    commented = pl.scan_csv(data, comment_prefix="#")

    # Full materialization, with and without a comment prefix.
    for lazy in (plain, commented):
        assert lazy.collect().height == 3

    # COUNT(*) path, with and without a comment prefix.
    for lazy in (plain, commented):
        assert lazy.select(pl.len()).collect().item() == 3
119
120
121
@pytest.mark.write_disk
def test_commented_csv(
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """A trailing comment line is not counted when ``comment_prefix`` is set.

    The CSV is written to a regular file under ``tmp_path``. Reopening a
    still-open ``NamedTemporaryFile`` by name is not supported on Windows, and
    writing the file up front also avoids relying on ``seek`` to flush the
    buffered write before polars reads the file.
    """
    tmp_path.mkdir(exist_ok=True)
    csv_path = tmp_path / "commented.csv"
    csv_path.write_bytes(b"A,B\nGr1,A\nGr1,B\n# comment line\n")

    lf = pl.scan_csv(csv_path, comment_prefix="#").select(pl.len())
    assert_fast_count(lf, 2, capfd=capfd, monkeypatch=monkeypatch)
131
132
133
@pytest.mark.parametrize(
    ("pattern", "n_rows"),
    [
        ("small.parquet", 4),
        ("foods*.parquet", 54),
    ],
)
def test_count_parquet(
    io_files_path: Path,
    pattern: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``COUNT(*)`` over a single Parquet file and over a glob of Parquet files."""
    source = io_files_path / pattern
    count_query = pl.scan_parquet(source).select(pl.len())

    assert_fast_count(count_query, n_rows, capfd=capfd, monkeypatch=monkeypatch)
145
146
147
@pytest.mark.parametrize(
    ("path", "n_rows"),
    [
        ("foods1.ipc", 27),
        ("foods*.ipc", 27 * 2),
    ],
)
def test_count_ipc(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``COUNT(*)`` over a single IPC file and over a glob of IPC files."""
    source = io_files_path / path
    count_query = pl.scan_ipc(source).select(pl.len())

    assert_fast_count(count_query, n_rows, capfd=capfd, monkeypatch=monkeypatch)
159
160
161
@pytest.mark.parametrize(
    ("path", "n_rows"),
    [
        ("foods1.ndjson", 27),
        ("foods*.ndjson", 27 * 2),
    ],
)
def test_count_ndjson(
    io_files_path: Path,
    path: str,
    n_rows: int,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``COUNT(*)`` over a single NDJSON file and over a glob of NDJSON files."""
    source = io_files_path / path
    count_query = pl.scan_ndjson(source).select(pl.len())

    assert_fast_count(count_query, n_rows, capfd=capfd, monkeypatch=monkeypatch)
173
174
175
def test_count_compressed_csv_18057(
    io_files_path: Path,
    capfd: pytest.CaptureFixture[str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Count rows of a gzip-compressed CSV (#18057).

    According to the original test's comment, the fixture file starts with an
    empty line. This therefore also covers #18070, "CSV count_rows does not skip
    empty lines at file start".
    """
    lf = pl.scan_csv(io_files_path / "gzipped.csv.gz", truncate_ragged_lines=True)

    assert_frame_equal(
        lf.collect(),
        pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]}),
    )

    assert_fast_count(lf.select(pl.len()), 3, capfd=capfd, monkeypatch=monkeypatch)
194
195
196
def test_count_compressed_ndjson(
    tmp_path: Path, capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """``COUNT(*)`` over a gzip-compressed NDJSON file written by the test."""
    tmp_path.mkdir(exist_ok=True)
    gz_path = tmp_path / "data.jsonl.gz"

    with gzip.open(gz_path, "wb") as fh:
        pl.DataFrame({"x": range(5)}).write_ndjson(fh)  # type: ignore[call-overload]

    assert_fast_count(
        pl.scan_ndjson(gz_path).select(pl.len()),
        5,
        capfd=capfd,
        monkeypatch=monkeypatch,
    )
208
209
210
def test_count_projection_pd(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """A ``len()`` over a row-indexed scan still yields zero-width projections.

    This plan is not converted to FAST COUNT, so it is checked by hand instead
    of through ``assert_fast_count``. Only ``project: 0`` may show up in the
    verbose logs.
    """
    csv_bytes = pl.DataFrame({"a": range(3), "b": range(3)}).write_csv().encode()
    query = (
        pl.scan_csv(csv_bytes).with_row_index().select(pl.all()).select(pl.len())
    )

    monkeypatch.setenv("POLARS_VERBOSE", "1")
    capfd.readouterr()  # drop previously captured output
    out = query.collect()
    stderr_text = capfd.readouterr().err

    assert set(re.findall(r"project: \d+", stderr_text)) == {"project: 0"}
    assert out.item() == 3
233
234
235
def test_csv_scan_skip_lines_len_22889(
    capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """The CSV row count must honour ``skip_lines`` (#22889)."""
    bb = b"col\n1\n2\n3"
    lf = pl.scan_csv(bb, skip_lines=2).select(pl.len())
    assert_fast_count(lf, 1, capfd=capfd, monkeypatch=monkeypatch)

    # `assert_fast_count` may leave POLARS_FAST_FILE_COUNT_DISPATCH set via
    # `monkeypatch`. Reset it so the next check starts from the default dispatch.
    monkeypatch.delenv("POLARS_FAST_FILE_COUNT_DISPATCH", raising=False)

    # Input large enough to trigger the multi-threaded code path.
    bb_10k = b"1\n2\n3\n4\n5\n6\n7\n8\n9\n0\n" * 1000
    lf = pl.scan_csv(bb_10k, skip_lines=1000, has_header=False).select(pl.len())
    assert_fast_count(lf, 9000, capfd=capfd, monkeypatch=monkeypatch)

    # For comparison: counting after a full materialization gives the same answer.
    out = pl.scan_csv(bb, skip_lines=2).collect().select(pl.len())
    # `len` returns the index type, which is not UInt32 on bigidx builds.
    expected = pl.DataFrame({"len": [1]}, schema={"len": pl.get_index_type()})
    assert_frame_equal(expected, out)
251
252